From 974e6820ceac90e1b70fa3f285ce4441f44c6049 Mon Sep 17 00:00:00 2001 From: vllmellm Date: Tue, 28 Oct 2025 16:26:51 +0000 Subject: [PATCH 01/36] first try Signed-off-by: vllmellm --- .../schemes/compressed_tensors_w8a8_fp8.py | 47 ++- .../kernels/scaled_mm/ScaledMMLinearKernel.py | 66 ++-- .../kernels/scaled_mm/__init__.py | 26 +- .../quantization/kernels/scaled_mm/aiter.py | 12 +- .../quantization/kernels/scaled_mm/cutlass.py | 256 +++++++++---- .../kernels/scaled_mm/flash_infer.py | 120 ++++++ .../quantization/kernels/scaled_mm/rocm.py | 179 +++++++++ .../quantization/kernels/scaled_mm/torch.py | 343 ++++++++++++++++++ 8 files changed, 924 insertions(+), 125 deletions(-) create mode 100644 vllm/model_executor/layers/quantization/kernels/scaled_mm/flash_infer.py create mode 100644 vllm/model_executor/layers/quantization/kernels/scaled_mm/rocm.py create mode 100644 vllm/model_executor/layers/quantization/kernels/scaled_mm/torch.py diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py index ee431c9148b8..c1108e96d213 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py @@ -10,6 +10,14 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( CompressedTensorsScheme, ) +from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( + _POSSIBLE_FP8_KERNELS, + choose_scaled_mm_linear_kernel, +) +from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( + ScaledMMLinearLayerConfig, + ScaledMMLinearQuantStrategy, +) from vllm.model_executor.layers.quantization.utils.fp8_utils import ( W8A8BlockFp8LinearOp, check_aiter_fp8_linear_support, @@ -24,7 +32,6 @@ ) from 
vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - Fp8LinearOp, cutlass_block_fp8_supported, maybe_create_device_identity, ) @@ -72,9 +79,32 @@ def __init__(self, weight_quant: QuantizationArgs, is_static_input_scheme: bool) use_aiter_and_is_supported=self.use_aiter_and_is_supported, ) else: - self.fp8_linear = Fp8LinearOp( - act_quant_static=self.is_static_input_scheme, - act_quant_group_shape=self.act_q_group_shape, + param_name_list = ["weight", "weight_scale", "input_scale"] + layer_mapping_function = lambda layer: ( + tuple(getattr(layer, param_name) for param_name in param_name_list), + param_name_list, + ) + + # TODO: clean up + if self.strategy == QuantizationStrategy.TENSOR: + weight_quant_strategy = ScaledMMLinearQuantStrategy.TENSOR + elif self.strategy == QuantizationStrategy.CHANNEL: + weight_quant_strategy = ScaledMMLinearQuantStrategy.CHANNEL + + scaled_mm_linear_kernel_config = ScaledMMLinearLayerConfig( + is_channelwise=(self.strategy == QuantizationStrategy.CHANNEL), + is_static_input_scheme=self.is_static_input_scheme, + input_symmetric=True, + weight_quant_strategy=weight_quant_strategy, + activation_group_shape=self.act_q_group_shape, + out_dtype=self.out_dtype, + ) + kernel = choose_scaled_mm_linear_kernel( + scaled_mm_linear_kernel_config, + _POSSIBLE_FP8_KERNELS, + ) + self.fp8_linear = kernel( + scaled_mm_linear_kernel_config, layer_mapping_function ) @classmethod @@ -190,11 +220,4 @@ def apply_weights( bias=bias, ) - return self.fp8_linear.apply( - input=x, - weight=layer.weight, - weight_scale=layer.weight_scale, - out_dtype=self.out_dtype, - input_scale=layer.input_scale, - bias=bias, - ) + return self.fp8_linear.apply_weights(layer, x, bias) diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py index 
2a885ec89945..0445223526c9 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py @@ -2,16 +2,36 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from abc import ABC, abstractmethod +from collections.abc import Callable from dataclasses import dataclass +from enum import Enum import torch +from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape + + +class ScaledMMLinearQuantStrategy(Enum): + TENSOR = "tensor" + CHANNEL = "channel" + BLOCK = "block" + + def is_per_token(self) -> bool: + return self.row == 1 and self.col == -1 + + def is_per_group(self) -> bool: + return self.row == 1 and self.col >= 1 + @dataclass class ScaledMMLinearLayerConfig: + # TODO: remove is channelwise is_channelwise: bool is_static_input_scheme: bool input_symmetric: bool + out_dtype: torch.dtype | None + weight_quant_strategy: ScaledMMLinearQuantStrategy + activation_group_shape: GroupShape | None = GroupShape.PER_TENSOR class ScaledMMLinearKernel(ABC): @@ -26,21 +46,11 @@ def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: raise NotImplementedError def __init__( - self, - c: ScaledMMLinearLayerConfig, - w_q_param_name: str, - w_s_param_name: str, - i_s_param_name: str, - i_zp_param_name: str, - azp_adj_param_name: str, + self, c: ScaledMMLinearLayerConfig, layer_mapping_function: Callable ) -> None: assert self.can_implement(c) self.config = c - self.w_q_name = w_q_param_name - self.w_s_name = w_s_param_name - self.i_s_name = i_s_param_name - self.i_zp_name = i_zp_param_name - self.azp_adj_name = azp_adj_param_name + self.layer_mapping_function = layer_mapping_function @abstractmethod def process_weights_after_loading(self, layer: torch.nn.Module) -> None: @@ -55,19 +65,19 @@ def apply_weights( ) -> torch.Tensor: raise NotImplementedError - def _get_weight_params( - self, layer: 
torch.nn.Module - ) -> tuple[ - torch.Tensor, # weight - torch.Tensor, # weight_scale - torch.Tensor | None, # input_scale, - torch.Tensor | None, # input_zp - torch.Tensor | None, # azp_adj - ]: - return ( - getattr(layer, self.w_q_name), - getattr(layer, self.w_s_name), - getattr(layer, self.i_s_name), - getattr(layer, self.i_zp_name), - getattr(layer, self.azp_adj_name), - ) + # def _get_weight_params( + # self, layer: torch.nn.Module + # ) -> tuple[ + # torch.Tensor, # weight + # torch.Tensor, # weight_scale + # torch.Tensor | None, # input_scale, + # torch.Tensor | None, # input_zp + # torch.Tensor | None, # azp_adj + # ]: + # return ( + # getattr(layer, self.w_q_name), + # getattr(layer, self.w_s_name), + # getattr(layer, self.i_s_name), + # getattr(layer, self.i_zp_name), + # getattr(layer, self.azp_adj_name), + # ) diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py index dd59e5d935dc..2ad21162995f 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py @@ -12,10 +12,18 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm.cutlass import ( CutlassScaledMMLinearKernel, ) +from vllm.model_executor.layers.quantization.kernels.scaled_mm.rocm import ( + ROCmScaledMMLinearKernel, +) from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501 ScaledMMLinearKernel, ScaledMMLinearLayerConfig, ) +from vllm.model_executor.layers.quantization.kernels.scaled_mm.torch import ( + ChannelWiseTorchScaledMMLinearKernel, + PerTensorTorchScaledMMLinearKernel, + RowWiseTorchScaledMMLinearKernel, +) from vllm.model_executor.layers.quantization.kernels.scaled_mm.triton import ( TritonScaledMMLinearKernel, ) @@ -25,16 +33,28 @@ from vllm.platforms import PlatformEnum, current_platform # in priority/performance order (when 
available) -_POSSIBLE_KERNELS: dict[PlatformEnum, list[type[ScaledMMLinearKernel]]] = { +_POSSIBLE_INT8_KERNELS: dict[PlatformEnum, list[type[ScaledMMLinearKernel]]] = { PlatformEnum.CPU: [CPUScaledMMLinearKernel], PlatformEnum.CUDA: [CutlassScaledMMLinearKernel], PlatformEnum.ROCM: [AiterScaledMMLinearKernel, TritonScaledMMLinearKernel], PlatformEnum.TPU: [XLAScaledMMLinearKernel], } +_POSSIBLE_FP8_KERNELS: dict[PlatformEnum, list[type[ScaledMMLinearKernel]]] = { + PlatformEnum.CUDA: [CutlassScaledMMLinearKernel], + PlatformEnum.ROCM: [ + ROCmScaledMMLinearKernel, + PerTensorTorchScaledMMLinearKernel, + RowWiseTorchScaledMMLinearKernel, + ChannelWiseTorchScaledMMLinearKernel, + ], +} + def choose_scaled_mm_linear_kernel( - config: ScaledMMLinearLayerConfig, compute_capability: int | None = None + config: ScaledMMLinearLayerConfig, + possible_kernels: dict[PlatformEnum, list[type[ScaledMMLinearKernel]]], + compute_capability: int | None = None, ) -> type[ScaledMMLinearKernel]: """ Choose an ScaledMMLinearKernel that can implement the given config for the @@ -61,7 +81,7 @@ def choose_scaled_mm_linear_kernel( compute_capability = _cc[0] * 10 + _cc[1] failure_reasons = [] - for kernel in _POSSIBLE_KERNELS[current_platform._enum]: + for kernel in possible_kernels[current_platform._enum]: if kernel.__name__ in os.environ.get("VLLM_DISABLED_KERNELS", "").split(","): failure_reasons.append( f" {kernel.__name__} disabled by environment variable" diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py index a19396a162bc..7dc1a57f1ecd 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py @@ -9,8 +9,8 @@ from vllm.platforms import current_platform from vllm.utils.torch_utils import direct_register_custom_op -from .cutlass import CutlassScaledMMLinearKernel -from .ScaledMMLinearKernel 
import ScaledMMLinearLayerConfig +from .cutlass import process_weights_after_loading +from .ScaledMMLinearKernel import ScaledMMLinearKernel, ScaledMMLinearLayerConfig def rocm_aiter_gemm_w8a8_impl( @@ -52,7 +52,7 @@ def rocm_aiter_gemm_w8a8_fake( ) -class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel): +class AiterScaledMMLinearKernel(ScaledMMLinearKernel): @classmethod def get_min_capability(cls) -> int: return 90 @@ -92,7 +92,9 @@ def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: return True, None def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - super().process_weights_after_loading(layer) + _, param_names = self.layer_mapping_function(layer) + + process_weights_after_loading(self.config, layer, *param_names) def apply_weights( self, @@ -110,7 +112,7 @@ def apply_weights( w8a8 scaled gemm. `AiterScaledMMLinearKernel` also does not support ATIER block scaled GEMM and mix-precision GEMM. """ - w_q, w_s, i_s, i_zp, azp_adj = self._get_weight_params(layer) + (w_q, w_s, i_s, i_zp, azp_adj), _ = self.layer_mapping_function(layer) # ops.scaled_int8_quant supports both dynamic and static quant: # * dynamic, i_s is None and x_s computed from x. 
diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py index e8769916b4ce..6e88d65acd45 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py @@ -2,10 +2,14 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections.abc import Callable + import torch from vllm import _custom_ops as ops +from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 from vllm.model_executor.layers.quantization.utils import replace_parameter +from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( convert_to_channelwise, ) @@ -14,6 +18,111 @@ from .ScaledMMLinearKernel import ScaledMMLinearKernel, ScaledMMLinearLayerConfig +def cutlass_w8a8_scaled_mm( + *, + A: torch.Tensor, + B: torch.Tensor, + out_dtype: torch.dtype, + As: torch.Tensor, + Bs: torch.Tensor, + bias: torch.Tensor, + output_shape: list, +) -> torch.Tensor: + # Fused GEMM_DQ + output = ops.cutlass_scaled_mm( + A, B, out_dtype=out_dtype, scale_a=As, scale_b=Bs, bias=bias + ) + return output.view(*output_shape) + + +def process_weights_after_loading( + config: ScaledMMLinearLayerConfig, + layer: torch.nn.Module, + w_q_name: str, + w_s_name: str, + i_s_name: str, + i_zp_name: str, + azp_adj_name: str, +): + # WEIGHT + # Cutlass kernels need transposed weight. + weight = getattr(layer, w_q_name) + replace_parameter( + layer, + w_q_name, + torch.nn.Parameter(weight.t().data, requires_grad=False), + ) + + # WEIGHT SCALE + # Cutlass kernels support only per-tensor and per-channel. + # If we have a fused module (QKV, MLP) with per tensor scales (thus N + # scales being passed to the kernel), convert to the per-channel case. 
+ is_fused_module = len(layer.logical_widths) > 1 + weight_scale = getattr(layer, w_s_name) + if is_fused_module and not config.is_channelwise: + weight_scale = convert_to_channelwise(weight_scale, layer.logical_widths) + replace_parameter( + layer, + w_s_name, + torch.nn.Parameter(weight_scale.data, requires_grad=False), + ) + + # INPUT SCALE + if config.is_static_input_scheme: + input_scale = getattr(layer, i_s_name) + + if config.input_symmetric: + replace_parameter( + layer, + i_s_name, + torch.nn.Parameter(input_scale.max(), requires_grad=False), + ) + setattr(layer, i_zp_name, None) + else: + input_zero_point = getattr(layer, i_zp_name) + + # reconstruct the ranges + int8_traits = torch.iinfo(torch.int8) + azps = input_zero_point.to(dtype=torch.int32) + range_max = (input_scale * (int8_traits.max - azps)).max() + range_min = (input_scale * (int8_traits.min - azps)).min() + + scale = (range_max - range_min) / (int8_traits.max - int8_traits.min) + replace_parameter( + layer, i_s_name, torch.nn.Parameter(scale, requires_grad=False) + ) + + # AZP loaded as int8 but used as int32 + azp = (int8_traits.min - range_min / scale).to(dtype=torch.int32) + replace_parameter( + layer, i_zp_name, torch.nn.Parameter(azp, requires_grad=False) + ) + + else: + setattr(layer, i_s_name, None) + setattr(layer, i_zp_name, None) + + # azp_adj is the AZP adjustment term, used to account for weights. + # It does not depend on scales or azp, so it is the same for + # static and dynamic quantization. 
+ # For more details, see csrc/quantization/w8a8/cutlass/Epilogues.md + # https://github.com/vllm-project/vllm/blob/main/csrc/quantization/w8a8/cutlass/Epilogues.md + if not config.input_symmetric: + weight = getattr(layer, w_q_name) + azp_adj = weight.sum(dim=0, keepdim=True, dtype=torch.int32) + if config.is_static_input_scheme: + # cutlass_w8a8 requires azp to be folded into azp_adj + # in the per-tensor case + azp_adj = getattr(layer, i_zp_name) * azp_adj + setattr( + layer, + azp_adj_name, + torch.nn.Parameter(azp_adj, requires_grad=False), + ) + else: + setattr(layer, azp_adj_name, None) + + class CutlassScaledMMLinearKernel(ScaledMMLinearKernel): @classmethod def get_min_capability(cls) -> int: @@ -27,83 +136,9 @@ def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: return True, None def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - # WEIGHT - # Cutlass kernels need transposed weight. - weight = getattr(layer, self.w_q_name) - replace_parameter( - layer, - self.w_q_name, - torch.nn.Parameter(weight.t().data, requires_grad=False), - ) - - # WEIGHT SCALE - # Cutlass kernels support only per-tensor and per-channel. - # If we have a fused module (QKV, MLP) with per tensor scales (thus N - # scales being passed to the kernel), convert to the per-channel case. 
- is_fused_module = len(layer.logical_widths) > 1 - weight_scale = getattr(layer, self.w_s_name) - if is_fused_module and not self.config.is_channelwise: - weight_scale = convert_to_channelwise(weight_scale, layer.logical_widths) - replace_parameter( - layer, - self.w_s_name, - torch.nn.Parameter(weight_scale.data, requires_grad=False), - ) + _, param_names = self.layer_mapping_function(layer) - # INPUT SCALE - if self.config.is_static_input_scheme: - input_scale = getattr(layer, self.i_s_name) - - if self.config.input_symmetric: - replace_parameter( - layer, - self.i_s_name, - torch.nn.Parameter(input_scale.max(), requires_grad=False), - ) - setattr(layer, self.i_zp_name, None) - else: - input_zero_point = getattr(layer, self.i_zp_name) - - # reconstruct the ranges - int8_traits = torch.iinfo(torch.int8) - azps = input_zero_point.to(dtype=torch.int32) - range_max = (input_scale * (int8_traits.max - azps)).max() - range_min = (input_scale * (int8_traits.min - azps)).min() - - scale = (range_max - range_min) / (int8_traits.max - int8_traits.min) - replace_parameter( - layer, self.i_s_name, torch.nn.Parameter(scale, requires_grad=False) - ) - - # AZP loaded as int8 but used as int32 - azp = (int8_traits.min - range_min / scale).to(dtype=torch.int32) - replace_parameter( - layer, self.i_zp_name, torch.nn.Parameter(azp, requires_grad=False) - ) - - else: - setattr(layer, self.i_s_name, None) - setattr(layer, self.i_zp_name, None) - - # azp_adj is the AZP adjustment term, used to account for weights. - # It does not depend on scales or azp, so it is the same for - # static and dynamic quantization. 
- # For more details, see csrc/quantization/w8a8/cutlass/Epilogues.md - # https://github.com/vllm-project/vllm/blob/main/csrc/quantization/w8a8/cutlass/Epilogues.md - if not self.config.input_symmetric: - weight = getattr(layer, self.w_q_name) - azp_adj = weight.sum(dim=0, keepdim=True, dtype=torch.int32) - if self.config.is_static_input_scheme: - # cutlass_w8a8 requires azp to be folded into azp_adj - # in the per-tensor case - azp_adj = getattr(layer, self.i_zp_name) * azp_adj - setattr( - layer, - self.azp_adj_name, - torch.nn.Parameter(azp_adj, requires_grad=False), - ) - else: - setattr(layer, self.azp_adj_name, None) + process_weights_after_loading(self.config, layer, *param_names) def apply_weights( self, @@ -111,7 +146,7 @@ def apply_weights( x: torch.Tensor, bias: torch.Tensor | None = None, ) -> torch.Tensor: - w_q, w_s, i_s, i_zp, azp_adj = self._get_weight_params(layer) + (w_q, w_s, i_s, i_zp, azp_adj), _ = self.layer_mapping_function(layer) # ops.scaled_int8_quant supports both dynamic and static quant: # * dynamic, i_s is None and x_s computed from x. 
@@ -138,3 +173,70 @@ def apply_weights( return ops.cutlass_scaled_mm( x_q, w_q, scale_a=x_s, scale_b=w_s, out_dtype=x.dtype, bias=bias ) + + +class CutlassFP8ScaledMMLinearKernel(ScaledMMLinearKernel): + def __init__( + self, c: ScaledMMLinearLayerConfig, layer_mapping_function: Callable + ) -> None: + self.quant_fp8 = QuantFP8( + static=c.is_static_input_scheme, + group_shape=GroupShape.PER_TENSOR, + num_token_padding=None, + ) + super().__init__(c, layer_mapping_function) + + @classmethod + def get_min_capability(cls) -> int: + # lovelace and up + return 89 + + @classmethod + def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: + if not current_platform.is_cuda(): + return ( + False, + "CutlassFP8ScaledMMLinearKernel is supported " + + "on CUDA platforms Only.", + ) + + return True, None + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + pass + + def apply_weights( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: torch.Tensor | None = None, + ): + # ops.scaled_fp8_quant supports both dynamic and static quant. + # If dynamic, layer.input_scale is None and x_scale computed from x. + # If static, layer.input_scale is scalar and x_scale is input_scale. 
+ (w, w_s, x_s), _ = self.layer_mapping_function(layer) + # View input as 2D matrix for fp8 methods + x_2d = x.view(-1, x.shape[-1]) + + out_dtype = self.config.out_dtype + out_dtype = x.dtype if out_dtype is None else out_dtype + # If input not quantized + # TODO(luka) remove this path if not used anymore + x_2d_q = x_2d + if x.dtype != current_platform.fp8_dtype(): + x_2d_q, x_s = self.quant_fp8( + x_2d, + x_s, + ) + + output_shape = [*x_2d_q.shape[:-1], w.shape[1]] + + return cutlass_w8a8_scaled_mm( + A=x_2d_q, + B=w, + out_dtype=out_dtype, + As=x_s, + Bs=w_s, + bias=bias, + output_shape=output_shape, + ) diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/flash_infer.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/flash_infer.py new file mode 100644 index 000000000000..9940ef49bb3e --- /dev/null +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/flash_infer.py @@ -0,0 +1,120 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections.abc import Callable + +import torch + +from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 +from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape +from vllm.platforms import current_platform +from vllm.utils.flashinfer import flashinfer_scaled_fp8_mm, has_flashinfer + +from .ScaledMMLinearKernel import ( + ScaledMMLinearKernel, + ScaledMMLinearLayerConfig, + ScaledMMLinearQuantStrategy, +) + + +def flashinfer_w8a8_scaled_mm( + *, + A: torch.Tensor, + B: torch.Tensor, + out_dtype: torch.dtype, + As: torch.Tensor, + Bs: torch.Tensor, + bias: torch.Tensor, + output_shape: list, +) -> torch.Tensor: + output = flashinfer_scaled_fp8_mm( + A, B, out_dtype=out_dtype, scale_a=As, scale_b=Bs, bias=bias + ) + return output.view(*output_shape) + + +class FlashInferScaledMMLinearKernel(ScaledMMLinearKernel): + def __init__( + self, c: ScaledMMLinearLayerConfig, layer_mapping_function: Callable + ) -> None: + self.quant_fp8 = QuantFP8( + 
static=c.is_static_input_scheme, + group_shape=GroupShape.PER_TENSOR, + num_token_padding=None, + ) + super().__init__(c, layer_mapping_function) + + @classmethod + def get_min_capability(cls) -> int: + return 100 + + @classmethod + def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: + per_tensor_activation_scales = c.activation_group_shape.is_per_tensor() + per_tensor_weight_scales = ( + c.weight_quant_strategy == ScaledMMLinearQuantStrategy.TENSOR + ) + + if not current_platform.is_cuda(): + return ( + False, + "FlashInferScaledMMLinearKernel is supported " + + "on CUDA platforms Only.", + ) + + if not has_flashinfer(): + return ( + False, + "FlashInferScaledMMLinearKernel requires " + + "FlashInfer to be installed.", + ) + + if not (per_tensor_activation_scales and per_tensor_weight_scales): + return ( + False, + "FlashInferScaledMMLinearKernel requires " + + "per tensor activation and weight scales.", + ) + return True, None + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + pass + + def apply_weights( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: torch.Tensor | None = None, + ): + # ops.scaled_fp8_quant supports both dynamic and static quant. + # If dynamic, layer.input_scale is None and x_scale computed from x. + # If static, layer.input_scale is scalar and x_scale is input_scale. 
+ (w, w_s, x_s), _ = self.layer_mapping_function(layer) + # View input as 2D matrix for fp8 methods + x_2d = x.view(-1, x.shape[-1]) + + out_dtype = self.config.out_dtype + out_dtype = x.dtype if out_dtype is None else out_dtype + # If input not quantized + # TODO(luka) remove this path if not used anymore + x_2d_q = x_2d + if x.dtype != current_platform.fp8_dtype(): + x_2d_q, x_s = self.quant_fp8( + x_2d, + x_s, + ) + + output_shape = [*x_2d_q.shape[:-1], w.shape[1]] + + return flashinfer_w8a8_scaled_mm( + A=x_2d_q, + B=w, + out_dtype=out_dtype, + As=x_s, + Bs=w_s, + bias=bias, + output_shape=output_shape, + ) diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/rocm.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/rocm.py new file mode 100644 index 000000000000..74454743fb0d --- /dev/null +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/rocm.py @@ -0,0 +1,179 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections.abc import Callable + +import torch + +import vllm.envs as envs +from vllm import _custom_ops as ops +from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 +from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape +from vllm.platforms import current_platform +from vllm.utils.torch_utils import direct_register_custom_op + +from .ScaledMMLinearKernel import ( + ScaledMMLinearKernel, + ScaledMMLinearLayerConfig, + ScaledMMLinearQuantStrategy, +) + + +def rocm_per_tensor_float_w8a8_scaled_mm_impl( + A: torch.Tensor, + B: torch.Tensor, + out_dtype: torch.dtype, + As: torch.Tensor, + Bs: torch.Tensor, + bias: torch.Tensor, +) -> torch.Tensor: + if ( + A.shape[0] == 1 + and B.shape[1] % 16 == 0 + and ((bias is None) or (bias.dtype == out_dtype)) + ): + output = ops.wvSplitKQ( + B.t(), + A, + out_dtype, + As, + Bs, + current_platform.get_cu_count(), + bias, + ) + # Fallabck + else: + output = 
torch._scaled_mm( + A, + B, + out_dtype=out_dtype, + scale_a=As, + scale_b=Bs, + bias=bias, + ) + return output + + +def rocm_per_tensor_float_w8a8_scaled_mm_fake( + A: torch.Tensor, + B: torch.Tensor, + out_dtype: torch.dtype, + As: torch.Tensor, + Bs: torch.Tensor, + bias: torch.Tensor, +) -> torch.Tensor: + return A.new_empty((*A.shape[:-1], B.shape[1]), dtype=out_dtype) + + +def rocm_per_tensor_float_w8a8_scaled_mm( + *, + A: torch.Tensor, + B: torch.Tensor, + out_dtype: torch.dtype, + As: torch.Tensor, + Bs: torch.Tensor, + bias: torch.Tensor, + output_shape: list[int], +) -> torch.Tensor: + output = torch.ops.vllm.rocm_per_tensor_float_w8a8_scaled_mm_impl( + A, B, out_dtype, As, Bs, bias + ) + return torch.narrow(output, 0, 0, A.shape[0]).view(*output_shape) + + +if current_platform.is_rocm(): + direct_register_custom_op( + op_name="rocm_per_tensor_float_w8a8_scaled_mm_impl", + op_func=rocm_per_tensor_float_w8a8_scaled_mm_impl, + fake_impl=rocm_per_tensor_float_w8a8_scaled_mm_fake, + ) + + +class ROCmScaledMMLinearKernel(ScaledMMLinearKernel): + def __init__( + self, c: ScaledMMLinearLayerConfig, layer_mapping_function: Callable + ) -> None: + self.quant_fp8 = QuantFP8( + static=c.is_static_input_scheme, + group_shape=GroupShape.PER_TENSOR, + num_token_padding=None, + ) + super().__init__(c, layer_mapping_function) + + @classmethod + def get_min_capability(cls) -> int: + return 90 + + @classmethod + def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: + # TODO: check if this causes an issue on non-ROCM platforms + from vllm.platforms.rocm import on_mi3xx + + per_tensor_activation_scales = c.activation_group_shape.is_per_tensor() + per_tensor_weight_scales = ( + c.weight_quant_strategy == ScaledMMLinearQuantStrategy.TENSOR + ) + + if not current_platform.is_rocm(): + return ( + False, + "ROCmScaledMMLinearKernel is supported " + "on ROCm platforms Only.", + ) + if not on_mi3xx(): + return ( + False, + "ROCmScaledMMLinearKernel is 
supported " + + "on MI3xx architures only.", + ) + if not envs.VLLM_ROCM_USE_SKINNY_GEMM: + return ( + False, + "VLLM_ROCM_USE_SKINNY_GEMM must be enabled " + + "to use ROCmScaledMMLinearKernel ", + ) + + if not (per_tensor_activation_scales and per_tensor_weight_scales): + return ( + False, + "ROCmScaledMMLinearKernel requires " + + "per tensor activation and weight scales.", + ) + return True, None + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + pass + + def apply_weights( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: torch.Tensor | None = None, + ): + # ops.scaled_fp8_quant supports both dynamic and static quant. + # If dynamic, layer.input_scale is None and x_scale computed from x. + # If static, layer.input_scale is scalar and x_scale is input_scale. + (w, w_s, x_s), _ = self.layer_mapping_function(layer) + # View input as 2D matrix for fp8 methods + x_2d = x.view(-1, x.shape[-1]) + + out_dtype = self.config.out_dtype + out_dtype = x.dtype if out_dtype is None else out_dtype + # If input not quantized + # TODO(luka) remove this path if not used anymore + x_2d_q = x_2d + if x.dtype != current_platform.fp8_dtype(): + x_2d_q, x_s = self.quant_fp8( + x_2d, + x_s, + ) + + output_shape = [*x_2d_q.shape[:-1], w.shape[1]] + + return rocm_per_tensor_float_w8a8_scaled_mm( + A=x_2d_q, + B=w, + out_dtype=out_dtype, + As=x_s, + Bs=w_s, + bias=bias, + output_shape=output_shape, + ) diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/torch.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/torch.py new file mode 100644 index 000000000000..0b2c0a8b49fd --- /dev/null +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/torch.py @@ -0,0 +1,343 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections.abc import Callable + +import torch +from packaging import version + +from vllm.config import CompilationMode, 
# Pure-torch w8a8 fp8 fallback kernels built on torch._scaled_mm.

# Input scaling factors are no longer optional in torch._scaled_mm starting
# from pytorch 2.5. We allocate a dummy ones tensor to pass as a scale when
# no real per-operand scale is needed.
TORCH_DEVICE_IDENTITY = None


def maybe_create_device_identity():
    """Lazily allocate the dummy identity scale used by torch._scaled_mm."""
    global TORCH_DEVICE_IDENTITY
    if TORCH_DEVICE_IDENTITY is None:
        TORCH_DEVICE_IDENTITY = torch.ones(1, dtype=torch.float32)


def _unwrap_scaled_mm_output(output):
    # torch._scaled_mm returned an (out, amax) tuple before torch 2.5 and a
    # single tensor from 2.5 on; normalize to the tensor.
    if type(output) is tuple and len(output) == 2:
        output = output[0]
    return output


def torch_per_tensor_w8a8_scaled_mm(
    *,
    A: torch.Tensor,
    B: torch.Tensor,
    out_dtype: torch.dtype,
    As: torch.Tensor,
    Bs: torch.Tensor,
    bias: torch.Tensor | None,
    output_shape: list,
) -> torch.Tensor:
    """Fused per-tensor-scaled fp8 GEMM: C = (As * A) @ (Bs * B) + bias."""
    output = torch._scaled_mm(
        A, B, out_dtype=out_dtype, scale_a=As, scale_b=Bs, bias=bias
    )
    output = _unwrap_scaled_mm_output(output)
    # Undo any num_token_padding added by the activation quantizer.
    return torch.narrow(output, 0, 0, A.shape[0]).view(*output_shape)


def torch_row_wise_w8a8_scaled_mm(
    *,
    A: torch.Tensor,
    B: torch.Tensor,
    out_dtype: torch.dtype,
    As: torch.Tensor,
    Bs: torch.Tensor,
    bias: torch.Tensor | None,
    output_shape: list,
) -> torch.Tensor:
    """Fused row-wise-scaled fp8 GEMM.

    Note: callers must verify torch._scaled_mm supports rowwise scaled GEMM
    before using this. fp8 rowwise scaling was introduced in
    https://github.com/pytorch/pytorch/pull/144432 using hipBLASLt and
    ROCm 6.3, which only exists in torch 2.7 and above; so far it has only
    been validated on the ROCm platform.
    """
    output = torch._scaled_mm(
        A,
        B,
        out_dtype=out_dtype,
        scale_a=As,
        scale_b=Bs.t(),
        bias=bias,
    )
    # Undo any num_token_padding added by the activation quantizer.
    return torch.narrow(output, 0, 0, A.shape[0]).view(*output_shape)


def torch_channelwise_w8a8_scaled_mm(
    *,
    A: torch.Tensor,
    B: torch.Tensor,
    out_dtype: torch.dtype,
    As: torch.Tensor,
    Bs: torch.Tensor,
    bias: torch.Tensor | None,
    output_shape: list,
) -> torch.Tensor:
    """Unfused channel-wise-scaled fp8 GEMM.

    torch._scaled_mm does not support a vector scale_b, so we run the GEMM
    with identity scales and dequantize afterwards. A symmetric quantized
    GEMM computes C = (s_x * X) (s_w * W) + bias, which we rewrite as
    C = s_w * s_x * (X * W) + bias.
    """
    global TORCH_DEVICE_IDENTITY
    # BUG FIX: the identity scale could still be None here if
    # maybe_create_device_identity() was never called, and it is allocated
    # on the default device — ensure it exists and lives with A.
    maybe_create_device_identity()
    if TORCH_DEVICE_IDENTITY.device != A.device:
        TORCH_DEVICE_IDENTITY = TORCH_DEVICE_IDENTITY.to(A.device)

    # GEMM: computes (X * W) in fp32 so the dequant below can run in-place.
    output = torch._scaled_mm(
        A,
        B,
        scale_a=TORCH_DEVICE_IDENTITY,
        scale_b=TORCH_DEVICE_IDENTITY,
        out_dtype=torch.float32,
    )
    output = _unwrap_scaled_mm_output(output)
    # Unpad (undo num_token_padding) on both the output and the row scales.
    output = torch.narrow(output, 0, 0, A.shape[0])
    x_scale = torch.narrow(As, 0, 0, A.shape[0])

    # DQ: C = s_w * s_x * (X * W) + bias
    output = output * x_scale * Bs.t()
    if bias is not None:
        output = output + bias
    return output.to(out_dtype).view(*output_shape)


class TorchScaledMMLinearKernel(ScaledMMLinearKernel):
    """Base for the pure-torch fp8 kernels: owns the activation quantizer
    and the shared quantize-then-GEMM path."""

    def __init__(
        self, c: ScaledMMLinearLayerConfig, layer_mapping_function: Callable
    ) -> None:
        compilation_config = get_current_vllm_config().compilation_config
        pad_output = compilation_config.mode < CompilationMode.VLLM_COMPILE

        # Pad the token dimension so torch._scaled_mm's M-alignment
        # requirement is met when not running under vllm-compile.
        output_padding = 17 if pad_output else None

        # NOTE(review): group_shape is hard-coded PER_TENSOR even for the
        # row-wise subclass, which requires non-per-tensor activation
        # scales — confirm against c.activation_group_shape.
        self.quant_fp8 = QuantFP8(
            static=c.is_static_input_scheme,
            group_shape=GroupShape.PER_TENSOR,
            num_token_padding=output_padding,
        )
        super().__init__(c, layer_mapping_function)

    @classmethod
    def get_min_capability(cls) -> int:
        # lovelace and up
        return 89

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        return

    def _run_scaled_mm(self, layer, x, bias, mm_func):
        """Quantize the activation if needed, then dispatch to mm_func.

        QuantFP8 supports both dynamic and static quant: if dynamic, the
        layer's input_scale is None and x_scale is computed from x; if
        static, x_scale is the stored input_scale.
        """
        (w, w_s, x_s), _ = self.layer_mapping_function(layer)
        # View input as a 2D matrix for the fp8 GEMM helpers.
        x_2d = x.view(-1, x.shape[-1])

        out_dtype = self.config.out_dtype
        out_dtype = x.dtype if out_dtype is None else out_dtype

        # If input not quantized yet.
        # TODO(luka) remove this path if not used anymore
        x_2d_q = x_2d
        if x.dtype != current_platform.fp8_dtype():
            x_2d_q, x_s = self.quant_fp8(x_2d, x_s)
        # NOTE(review): output_shape uses the (possibly padded) quantized
        # rows, so narrow() inside mm_func does not unpad — confirm whether
        # x_2d.shape was intended here.
        output_shape = [*x_2d_q.shape[:-1], w.shape[1]]
        return mm_func(
            A=x_2d_q,
            B=w,
            out_dtype=out_dtype,
            As=x_s,
            Bs=w_s,
            bias=bias,
            output_shape=output_shape,
        )


class PerTensorTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
    @classmethod
    def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
        assert c.activation_group_shape is not None
        per_tensor_activation_scales = c.activation_group_shape.is_per_tensor()
        per_tensor_weight_scales = (
            c.weight_quant_strategy == ScaledMMLinearQuantStrategy.TENSOR
        )

        if not (per_tensor_activation_scales and per_tensor_weight_scales):
            return (
                False,
                "PerTensorTorchScaledMMLinearKernel requires "
                + "per tensor activation and weight scales.",
            )
        return True, None

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ):
        return self._run_scaled_mm(layer, x, bias, torch_per_tensor_w8a8_scaled_mm)


class RowWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
    @classmethod
    def get_min_capability(cls) -> int:
        return 94

    @classmethod
    def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
        assert c.activation_group_shape is not None

        per_tensor_activation_scales = c.activation_group_shape.is_per_tensor()
        per_tensor_weight_scales = (
            c.weight_quant_strategy == ScaledMMLinearQuantStrategy.TENSOR
        )

        if per_tensor_activation_scales and per_tensor_weight_scales:
            return (
                False,
                "RowWiseTorchScaledMMLinearKernel cannot be used with "
                + "per tensor activation and weight scales.",
            )

        if not current_platform.is_rocm():
            return (
                False,
                "RowWiseTorchScaledMMLinearKernel is only supported "
                + "in ROCm platforms.",
            )

        if not version.parse(torch.__version__) >= version.parse("2.7"):
            return (
                False,
                "RowWiseTorchScaledMMLinearKernel requires "
                "pytorch version >=2.7.",
            )

        return True, None

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ):
        return self._run_scaled_mm(layer, x, bias, torch_row_wise_w8a8_scaled_mm)


class ChannelWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
    @classmethod
    def get_min_capability(cls) -> int:
        return 94

    @classmethod
    def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
        assert c.activation_group_shape is not None

        per_tensor_activation_scales = c.activation_group_shape.is_per_tensor()
        per_tensor_weight_scales = (
            c.weight_quant_strategy == ScaledMMLinearQuantStrategy.TENSOR
        )

        if per_tensor_activation_scales and per_tensor_weight_scales:
            return (
                False,
                "ChannelWiseTorchScaledMMLinearKernel cannot be used with "
                + "per tensor activation and weight scales.",
            )

        return True, None

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ):
        return self._run_scaled_mm(layer, x, bias, torch_channelwise_w8a8_scaled_mm)
b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py @@ -15,7 +15,7 @@ choose_scaled_mm_linear_kernel, ) from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( - ScaledMMLinearLayerConfig, + FP8ScaledMMLinearLayerConfig, ScaledMMLinearQuantStrategy, ) from vllm.model_executor.layers.quantization.utils.fp8_utils import ( @@ -91,10 +91,7 @@ def __init__(self, weight_quant: QuantizationArgs, is_static_input_scheme: bool) elif self.strategy == QuantizationStrategy.CHANNEL: weight_quant_strategy = ScaledMMLinearQuantStrategy.CHANNEL - scaled_mm_linear_kernel_config = ScaledMMLinearLayerConfig( - is_channelwise=(self.strategy == QuantizationStrategy.CHANNEL), - is_static_input_scheme=self.is_static_input_scheme, - input_symmetric=True, + scaled_mm_linear_kernel_config = FP8ScaledMMLinearLayerConfig( weight_quant_strategy=weight_quant_strategy, activation_group_shape=self.act_q_group_shape, out_dtype=self.out_dtype, diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py index 6fd0a6a1c822..049f96f1faa3 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py @@ -11,8 +11,8 @@ CompressedTensorsScheme, ) from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( - ScaledMMLinearLayerConfig, choose_scaled_mm_linear_kernel, + _POSSIBLE_INT8_KERNELS ) from vllm.model_executor.parameter import ( BasevLLMParameter, @@ -20,6 +20,7 @@ ModelWeightParameter, PerTensorScaleParameter, ) +from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import Int8ScaledMMLinearLayerConfig logger = init_logger(__name__) @@ -50,13 +51,16 @@ def create_weights( 
): layer.logical_widths = output_partition_sizes - scaled_mm_linear_kernel_config = ScaledMMLinearLayerConfig( + scaled_mm_linear_kernel_config = Int8ScaledMMLinearLayerConfig( is_channelwise=(self.strategy == QuantizationStrategy.CHANNEL), is_static_input_scheme=self.is_static_input_scheme, input_symmetric=self.input_symmetric, ) - kernel_type = choose_scaled_mm_linear_kernel(scaled_mm_linear_kernel_config) + kernel_type = choose_scaled_mm_linear_kernel( + scaled_mm_linear_kernel_config, + _POSSIBLE_INT8_KERNELS + ) if kernel_type.__name__ not in self._kernel_backends_being_used: logger.info("Using %s for CompressedTensorsW8A8Int8", kernel_type.__name__) @@ -90,12 +94,12 @@ def create_weights( layer.register_parameter("weight_scale", weight_scale) # INPUT SCALE + input_zero_point=None + input_scale=None if self.is_static_input_scheme: input_scale = BasevLLMParameter( data=torch.empty(1, dtype=torch.float32), weight_loader=weight_loader ) - layer.register_parameter("input_scale", input_scale) - if not self.input_symmetric: # Note: compressed-tensors stores the zp using the same dtype # as the weights @@ -103,15 +107,21 @@ def create_weights( input_zero_point = BasevLLMParameter( data=torch.empty(1, dtype=torch.int8), weight_loader=weight_loader ) - layer.register_parameter("input_zero_point", input_zero_point) + layer.register_parameter("input_zero_point", input_zero_point) + layer.register_parameter("input_scale", input_scale) + if not hasattr(layer, "azp_adj"): + layer.register_parameter("azp_adj", None) + + param_name_list = ["weight", "weight_scale", "input_scale", "input_zero_point", "azp_adj"] + + layer_mapping_function = lambda layer: ( + tuple(getattr(layer, param_name) for param_name in param_name_list), + param_name_list, + ) self.kernel = kernel_type( c=scaled_mm_linear_kernel_config, - w_q_param_name="weight", - w_s_param_name="weight_scale", - i_s_param_name="input_scale", - i_zp_param_name="input_zero_point", - azp_adj_param_name="azp_adj", + 
layer_mapping_function = layer_mapping_function ) # Checkpoints are serialized in compressed-tensors format, which is diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py index 0445223526c9..e12aa2c5c4d2 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py @@ -5,7 +5,7 @@ from collections.abc import Callable from dataclasses import dataclass from enum import Enum - +from typing import Generic, TypeVar import torch from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape @@ -25,16 +25,25 @@ def is_per_group(self) -> bool: @dataclass class ScaledMMLinearLayerConfig: - # TODO: remove is channelwise + pass + +@dataclass +class Int8ScaledMMLinearLayerConfig(ScaledMMLinearLayerConfig): is_channelwise: bool is_static_input_scheme: bool input_symmetric: bool - out_dtype: torch.dtype | None + +@dataclass +class FP8ScaledMMLinearLayerConfig(ScaledMMLinearLayerConfig): weight_quant_strategy: ScaledMMLinearQuantStrategy - activation_group_shape: GroupShape | None = GroupShape.PER_TENSOR + activation_group_shape: GroupShape + out_dtype: torch.dtype -class ScaledMMLinearKernel(ABC): +ConfigT = TypeVar('ConfigT', bound=ScaledMMLinearLayerConfig) + + +class ScaledMMLinearKernel(Generic[ConfigT], ABC): @classmethod @abstractmethod def get_min_capability(cls) -> int: @@ -42,11 +51,11 @@ def get_min_capability(cls) -> int: @classmethod @abstractmethod - def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: + def can_implement(cls, c: ConfigT) -> tuple[bool, str | None]: raise NotImplementedError def __init__( - self, c: ScaledMMLinearLayerConfig, layer_mapping_function: Callable + self, c: ConfigT, layer_mapping_function: Callable ) -> None: assert self.can_implement(c) self.config 
= c @@ -63,21 +72,4 @@ def apply_weights( x: torch.Tensor, bias: torch.Tensor | None = None, ) -> torch.Tensor: - raise NotImplementedError - - # def _get_weight_params( - # self, layer: torch.nn.Module - # ) -> tuple[ - # torch.Tensor, # weight - # torch.Tensor, # weight_scale - # torch.Tensor | None, # input_scale, - # torch.Tensor | None, # input_zp - # torch.Tensor | None, # azp_adj - # ]: - # return ( - # getattr(layer, self.w_q_name), - # getattr(layer, self.w_s_name), - # getattr(layer, self.i_s_name), - # getattr(layer, self.i_zp_name), - # getattr(layer, self.azp_adj_name), - # ) + raise NotImplementedError \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py index 2ad21162995f..85aaf51ae844 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py @@ -19,6 +19,7 @@ ScaledMMLinearKernel, ScaledMMLinearLayerConfig, ) + from vllm.model_executor.layers.quantization.kernels.scaled_mm.torch import ( ChannelWiseTorchScaledMMLinearKernel, PerTensorTorchScaledMMLinearKernel, diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py index 7dc1a57f1ecd..a39e96bca614 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py @@ -10,7 +10,7 @@ from vllm.utils.torch_utils import direct_register_custom_op from .cutlass import process_weights_after_loading -from .ScaledMMLinearKernel import ScaledMMLinearKernel, ScaledMMLinearLayerConfig +from .ScaledMMLinearKernel import ScaledMMLinearKernel, Int8ScaledMMLinearLayerConfig def rocm_aiter_gemm_w8a8_impl( @@ -58,7 +58,7 @@ def get_min_capability(cls) -> int: return 90 @classmethod - def can_implement(cls, c: 
ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: + def can_implement(cls, c: Int8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: if not current_platform.is_rocm(): return ( False, diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/cpu.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/cpu.py index feb1e0bee1aa..9c8ece8559b4 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/cpu.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/cpu.py @@ -14,7 +14,7 @@ from vllm.platforms import current_platform from vllm.platforms.interface import CpuArchEnum -from .ScaledMMLinearKernel import ScaledMMLinearKernel, ScaledMMLinearLayerConfig +from .ScaledMMLinearKernel import ScaledMMLinearKernel, Int8ScaledMMLinearLayerConfig class CPUScaledMMLinearKernel(ScaledMMLinearKernel): @@ -23,7 +23,7 @@ def get_min_capability(cls) -> int: return 75 @classmethod - def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: + def can_implement(cls, c: Int8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: if not current_platform.is_cpu(): return False, "CPUScaledMM requires running on CPU." 
diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py index 6e88d65acd45..b81d67068693 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py @@ -15,7 +15,7 @@ ) from vllm.platforms import current_platform -from .ScaledMMLinearKernel import ScaledMMLinearKernel, ScaledMMLinearLayerConfig +from .ScaledMMLinearKernel import ScaledMMLinearKernel, Int8ScaledMMLinearLayerConfig, FP8ScaledMMLinearLayerConfig def cutlass_w8a8_scaled_mm( @@ -36,7 +36,7 @@ def cutlass_w8a8_scaled_mm( def process_weights_after_loading( - config: ScaledMMLinearLayerConfig, + config: Int8ScaledMMLinearLayerConfig, layer: torch.nn.Module, w_q_name: str, w_s_name: str, @@ -98,9 +98,6 @@ def process_weights_after_loading( layer, i_zp_name, torch.nn.Parameter(azp, requires_grad=False) ) - else: - setattr(layer, i_s_name, None) - setattr(layer, i_zp_name, None) # azp_adj is the AZP adjustment term, used to account for weights. # It does not depend on scales or azp, so it is the same for @@ -119,8 +116,6 @@ def process_weights_after_loading( azp_adj_name, torch.nn.Parameter(azp_adj, requires_grad=False), ) - else: - setattr(layer, azp_adj_name, None) class CutlassScaledMMLinearKernel(ScaledMMLinearKernel): @@ -129,7 +124,7 @@ def get_min_capability(cls) -> int: return 75 @classmethod - def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: + def can_implement(cls, c: Int8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: if not current_platform.is_cuda(): return False, "CutlassScaledMM requires running on CUDA." 
@@ -177,7 +172,7 @@ def apply_weights( class CutlassFP8ScaledMMLinearKernel(ScaledMMLinearKernel): def __init__( - self, c: ScaledMMLinearLayerConfig, layer_mapping_function: Callable + self, c: FP8ScaledMMLinearLayerConfig, layer_mapping_function: Callable ) -> None: self.quant_fp8 = QuantFP8( static=c.is_static_input_scheme, @@ -192,7 +187,7 @@ def get_min_capability(cls) -> int: return 89 @classmethod - def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: + def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: if not current_platform.is_cuda(): return ( False, diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/flash_infer.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/flash_infer.py index 9940ef49bb3e..9fcbb2ff8ec8 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/flash_infer.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/flash_infer.py @@ -11,7 +11,7 @@ from .ScaledMMLinearKernel import ( ScaledMMLinearKernel, - ScaledMMLinearLayerConfig, + Int8ScaledMMLinearLayerConfig, ScaledMMLinearQuantStrategy, ) @@ -32,7 +32,7 @@ def flashinfer_w8a8_scaled_mm( class FlashInferScaledMMLinearKernel(ScaledMMLinearKernel): def __init__( - self, c: ScaledMMLinearLayerConfig, layer_mapping_function: Callable + self, c: Int8ScaledMMLinearLayerConfig, layer_mapping_function: Callable ) -> None: self.quant_fp8 = QuantFP8( static=c.is_static_input_scheme, @@ -46,7 +46,7 @@ def get_min_capability(cls) -> int: return 100 @classmethod - def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: + def can_implement(cls, c: Int8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: per_tensor_activation_scales = c.activation_group_shape.is_per_tensor() per_tensor_weight_scales = ( c.weight_quant_strategy == ScaledMMLinearQuantStrategy.TENSOR diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/rocm.py 
b/vllm/model_executor/layers/quantization/kernels/scaled_mm/rocm.py index 74454743fb0d..17b932f2336d 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/rocm.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/rocm.py @@ -13,7 +13,7 @@ from .ScaledMMLinearKernel import ( ScaledMMLinearKernel, - ScaledMMLinearLayerConfig, + FP8ScaledMMLinearLayerConfig, ScaledMMLinearQuantStrategy, ) @@ -90,7 +90,7 @@ def rocm_per_tensor_float_w8a8_scaled_mm( class ROCmScaledMMLinearKernel(ScaledMMLinearKernel): def __init__( - self, c: ScaledMMLinearLayerConfig, layer_mapping_function: Callable + self, c: FP8ScaledMMLinearLayerConfig, layer_mapping_function: Callable ) -> None: self.quant_fp8 = QuantFP8( static=c.is_static_input_scheme, @@ -104,7 +104,7 @@ def get_min_capability(cls) -> int: return 90 @classmethod - def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: + def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: # TODO: check if this causes an issue on non-ROCM platforms from vllm.platforms.rocm import on_mi3xx diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/torch.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/torch.py index 0b2c0a8b49fd..7d82496dca02 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/torch.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/torch.py @@ -12,7 +12,7 @@ from .ScaledMMLinearKernel import ( ScaledMMLinearKernel, - ScaledMMLinearLayerConfig, + FP8ScaledMMLinearLayerConfig, ScaledMMLinearQuantStrategy, ) @@ -136,7 +136,7 @@ def torch_channelwise_w8a8_scaled_mm( class TorchScaledMMLinearKernel(ScaledMMLinearKernel): def __init__( - self, c: ScaledMMLinearLayerConfig, layer_mapping_function: Callable + self, c: FP8ScaledMMLinearLayerConfig, layer_mapping_function: Callable ) -> None: vllm_config = get_current_vllm_config().compilation_config pad_output = vllm_config.mode < 
CompilationMode.VLLM_COMPILE @@ -161,7 +161,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: class PerTensorTorchScaledMMLinearKernel(TorchScaledMMLinearKernel): @classmethod - def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: + def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: assert c.activation_group_shape is not None per_tensor_activation_scales = c.activation_group_shape.is_per_tensor() per_tensor_weight_scales = ( @@ -218,7 +218,7 @@ def get_min_capability(cls) -> int: return 94 @classmethod - def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: + def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: assert c.activation_group_shape is not None per_tensor_activation_scales = c.activation_group_shape.is_per_tensor() @@ -290,7 +290,7 @@ def get_min_capability(cls) -> int: return 94 @classmethod - def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: + def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: assert c.activation_group_shape is not None per_tensor_activation_scales = c.activation_group_shape.is_per_tensor() diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/triton.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/triton.py index 3f4ec7f2a738..0c8ee18457dd 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/triton.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/triton.py @@ -7,7 +7,7 @@ from vllm.platforms import current_platform from .cutlass import CutlassScaledMMLinearKernel -from .ScaledMMLinearKernel import ScaledMMLinearLayerConfig +from .ScaledMMLinearKernel import Int8ScaledMMLinearLayerConfig class TritonScaledMMLinearKernel(CutlassScaledMMLinearKernel): @@ -16,7 +16,7 @@ def get_min_capability(cls) -> int: return 75 @classmethod - def can_implement(cls, c: 
ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: + def can_implement(cls, c: Int8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: if current_platform.is_cpu(): return ( False, diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py index ddac9f13cf4f..6150270c8773 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py @@ -12,7 +12,7 @@ ) from vllm.platforms import current_platform -from .ScaledMMLinearKernel import ScaledMMLinearKernel, ScaledMMLinearLayerConfig +from .ScaledMMLinearKernel import ScaledMMLinearKernel, Int8ScaledMMLinearLayerConfig class XLAScaledMMLinearKernel(ScaledMMLinearKernel): @@ -24,7 +24,7 @@ def get_min_capability(cls) -> int: ) @classmethod - def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: + def can_implement(cls, c: Int8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: if not current_platform.is_tpu(): return False, "ScaledMMXLA requires running on TPU." 
diff --git a/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_int8.py b/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_int8.py index 42d2ed2e85ed..3d51ea2cd958 100644 --- a/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_int8.py +++ b/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_int8.py @@ -7,9 +7,9 @@ from vllm.logger import init_logger from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( - ScaledMMLinearLayerConfig, choose_scaled_mm_linear_kernel, ) +from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import Int8ScaledMMLinearLayerConfig from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme from vllm.model_executor.parameter import ( BasevLLMParameter, @@ -50,7 +50,7 @@ def create_weights( ): layer.logical_widths = output_partition_sizes - scaled_mm_linear_kernel_config = ScaledMMLinearLayerConfig( + scaled_mm_linear_kernel_config = Int8ScaledMMLinearLayerConfig( is_channelwise=(self.qscheme == "per_channel"), is_static_input_scheme=(self.is_static_input_scheme is True), input_symmetric=(self.input_symmetric is True), From c05027f67a8d8cc645207163a43838ffbf90174a Mon Sep 17 00:00:00 2001 From: vllmellm Date: Thu, 30 Oct 2025 12:27:04 +0000 Subject: [PATCH 03/36] clean up; fix quark path Signed-off-by: vllmellm --- .../schemes/compressed_tensors_w8a8_fp8.py | 26 +- .../schemes/compressed_tensors_w8a8_int8.py | 8 +- .../kernels/scaled_mm/ScaledMMLinearKernel.py | 91 +++++-- .../quantization/kernels/scaled_mm/aiter.py | 13 +- .../quantization/kernels/scaled_mm/cpu.py | 66 +++-- .../quantization/kernels/scaled_mm/cutlass.py | 232 +++++++----------- .../kernels/scaled_mm/flash_infer.py | 60 ++--- .../quantization/kernels/scaled_mm/rocm.py | 64 ++--- .../quantization/kernels/scaled_mm/torch.py | 146 +++-------- .../quantization/kernels/scaled_mm/utils.py | 44 ++++ .../quantization/kernels/scaled_mm/xla.py | 23 +- 
.../quark/schemes/quark_w8a8_int8.py | 17 +- 12 files changed, 357 insertions(+), 433 deletions(-) create mode 100644 vllm/model_executor/layers/quantization/kernels/scaled_mm/utils.py diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py index bd9a6bd0ef04..53e7ed2fb3fc 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py @@ -6,6 +6,7 @@ import torch from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy from torch.nn import Parameter +from vllm.logger import init_logger from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( CompressedTensorsScheme, @@ -17,6 +18,7 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( FP8ScaledMMLinearLayerConfig, ScaledMMLinearQuantStrategy, + QUANT_STRATEGY_MAP, ) from vllm.model_executor.layers.quantization.utils.fp8_utils import ( W8A8BlockFp8LinearOp, @@ -49,8 +51,11 @@ QuantizationStrategy.TENSOR: PerTensorScaleParameter, } +logger = init_logger(__name__) class CompressedTensorsW8A8Fp8(CompressedTensorsScheme): + _kernel_backends_being_used: set[str] = set() + def __init__(self, weight_quant: QuantizationArgs, is_static_input_scheme: bool): self.weight_quant = weight_quant self.strategy = weight_quant.strategy @@ -79,19 +84,10 @@ def __init__(self, weight_quant: QuantizationArgs, is_static_input_scheme: bool) use_aiter_and_is_supported=self.use_aiter_and_is_supported, ) else: - param_name_list = ["weight", "weight_scale", "input_scale"] - layer_mapping_function = lambda layer: ( - tuple(getattr(layer, param_name) for param_name in param_name_list), - param_name_list, - ) - - # TODO: clean up - if self.strategy == 
QuantizationStrategy.TENSOR: - weight_quant_strategy = ScaledMMLinearQuantStrategy.TENSOR - elif self.strategy == QuantizationStrategy.CHANNEL: - weight_quant_strategy = ScaledMMLinearQuantStrategy.CHANNEL - + layer_param_names = ["weight", "weight_scale", "input_scale"] + weight_quant_strategy = QUANT_STRATEGY_MAP[self.strategy] scaled_mm_linear_kernel_config = FP8ScaledMMLinearLayerConfig( + is_static_input_scheme=self.is_static_input_scheme, weight_quant_strategy=weight_quant_strategy, activation_group_shape=self.act_q_group_shape, out_dtype=self.out_dtype, @@ -101,9 +97,13 @@ def __init__(self, weight_quant: QuantizationArgs, is_static_input_scheme: bool) _POSSIBLE_FP8_KERNELS, ) self.fp8_linear = kernel( - scaled_mm_linear_kernel_config, layer_mapping_function + scaled_mm_linear_kernel_config, layer_param_names = layer_param_names ) + if kernel.__name__ not in self._kernel_backends_being_used: + logger.info("Using %s for CompressedTensorsW8A8FP8", kernel.__name__) + self._kernel_backends_being_used.add(kernel.__name__) + @classmethod def get_min_capability(cls) -> int: # lovelace and up diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py index 049f96f1faa3..a0ae8655ca65 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py @@ -113,15 +113,11 @@ def create_weights( if not hasattr(layer, "azp_adj"): layer.register_parameter("azp_adj", None) - param_name_list = ["weight", "weight_scale", "input_scale", "input_zero_point", "azp_adj"] + layer_param_names = ["weight", "weight_scale", "input_scale", "input_zero_point", "azp_adj"] - layer_mapping_function = lambda layer: ( - tuple(getattr(layer, param_name) for param_name in param_name_list), - 
param_name_list, - ) self.kernel = kernel_type( c=scaled_mm_linear_kernel_config, - layer_mapping_function = layer_mapping_function + layer_param_names = layer_param_names ) # Checkpoints are serialized in compressed-tensors format, which is diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py index e12aa2c5c4d2..27af30ae131c 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py @@ -5,10 +5,12 @@ from collections.abc import Callable from dataclasses import dataclass from enum import Enum -from typing import Generic, TypeVar +from typing import Generic, Sequence, TypeVar import torch +from compressed_tensors.quantization import QuantizationStrategy from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape +from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 class ScaledMMLinearQuantStrategy(Enum): @@ -16,21 +18,19 @@ class ScaledMMLinearQuantStrategy(Enum): CHANNEL = "channel" BLOCK = "block" - def is_per_token(self) -> bool: - return self.row == 1 and self.col == -1 - - def is_per_group(self) -> bool: - return self.row == 1 and self.col >= 1 - +QUANT_STRATEGY_MAP = { + QuantizationStrategy.TENSOR: ScaledMMLinearQuantStrategy.TENSOR, + QuantizationStrategy.CHANNEL: ScaledMMLinearQuantStrategy.CHANNEL, + QuantizationStrategy.CHANNEL: ScaledMMLinearQuantStrategy.BLOCK, +} @dataclass class ScaledMMLinearLayerConfig: - pass + is_static_input_scheme: bool @dataclass class Int8ScaledMMLinearLayerConfig(ScaledMMLinearLayerConfig): is_channelwise: bool - is_static_input_scheme: bool input_symmetric: bool @dataclass @@ -40,10 +40,24 @@ class FP8ScaledMMLinearLayerConfig(ScaledMMLinearLayerConfig): out_dtype: torch.dtype -ConfigT = TypeVar('ConfigT', 
bound=ScaledMMLinearLayerConfig) +FP8ParamsT = tuple[ + torch.Tensor, # weight + torch.Tensor, # weight_scale + torch.Tensor | None, # input_scale, +] +Int8ParamsT = tuple[ + torch.Tensor, # weight + torch.Tensor, # weight_scale + torch.Tensor | None, # input_scale, + torch.Tensor | None, # input_zp + torch.Tensor | None, # azp_adj + ] -class ScaledMMLinearKernel(Generic[ConfigT], ABC): +ParamsT = TypeVar('ParamsT', Int8ParamsT, FP8ParamsT) +ConfigT = TypeVar('ConfigT', bound=ScaledMMLinearLayerConfig) + +class ScaledMMLinearKernel(Generic[ConfigT, ParamsT], ABC): @classmethod @abstractmethod def get_min_capability(cls) -> int: @@ -55,11 +69,11 @@ def can_implement(cls, c: ConfigT) -> tuple[bool, str | None]: raise NotImplementedError def __init__( - self, c: ConfigT, layer_mapping_function: Callable + self, c: ConfigT, layer_param_names: Sequence[str] ) -> None: assert self.can_implement(c) self.config = c - self.layer_mapping_function = layer_mapping_function + self.layer_param_names = layer_param_names @abstractmethod def process_weights_after_loading(self, layer: torch.nn.Module) -> None: @@ -72,4 +86,53 @@ def apply_weights( x: torch.Tensor, bias: torch.Tensor | None = None, ) -> torch.Tensor: - raise NotImplementedError \ No newline at end of file + raise NotImplementedError + + # return a covariant type in the subclass + @abstractmethod + def _get_layer_params(self, layer) -> ParamsT: + raise NotImplementedError + + +class FP8ScaledMMLinearKernel(ScaledMMLinearKernel[FP8ScaledMMLinearLayerConfig, FP8ParamsT], ABC): + def __init__( + self, c: ConfigT, layer_param_names: Sequence[str] + ) -> None: + self.quant_fp8 = QuantFP8( + static=c.is_static_input_scheme, + group_shape=c.activation_group_shape, + num_token_padding=self.get_ouput_padding(), + ) + super().__init__(c, layer_param_names) + + @abstractmethod + def get_ouput_padding(self) -> int | None: + raise NotImplementedError + + @classmethod + def get_min_capability(cls) -> int: + # lovelace and up + 
return 89 + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + pass + + def _get_layer_params(self, layer) -> FP8ParamsT: + w, w_s, x_s = self.layer_param_names + return ( + getattr(layer, w), + getattr(layer, w_s), + getattr(layer, x_s), + ) + + +class Int8ScaledMMLinearKernel(ScaledMMLinearKernel[Int8ScaledMMLinearLayerConfig, Int8ParamsT], ABC): + def _get_layer_params(self, layer) -> Int8ParamsT: + w_q, w_s, i_s, i_zp, azp_adj = self.layer_param_names + return ( + getattr(layer, w_q), + getattr(layer, w_s), + getattr(layer, i_s), + getattr(layer, i_zp), + getattr(layer, azp_adj), + ) diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py index a39e96bca614..3ac90553bbc7 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py @@ -9,8 +9,8 @@ from vllm.platforms import current_platform from vllm.utils.torch_utils import direct_register_custom_op -from .cutlass import process_weights_after_loading -from .ScaledMMLinearKernel import ScaledMMLinearKernel, Int8ScaledMMLinearLayerConfig +from .cutlass import CutlassScaledMMLinearKernel +from .ScaledMMLinearKernel import Int8ScaledMMLinearLayerConfig def rocm_aiter_gemm_w8a8_impl( @@ -52,7 +52,7 @@ def rocm_aiter_gemm_w8a8_fake( ) -class AiterScaledMMLinearKernel(ScaledMMLinearKernel): +class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel): @classmethod def get_min_capability(cls) -> int: return 90 @@ -91,11 +91,6 @@ def can_implement(cls, c: Int8ScaledMMLinearLayerConfig) -> tuple[bool, str | No ) return True, None - def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - _, param_names = self.layer_mapping_function(layer) - - process_weights_after_loading(self.config, layer, *param_names) - def apply_weights( self, layer: torch.nn.Module, @@ -112,7 +107,7 @@ def apply_weights( 
w8a8 scaled gemm. `AiterScaledMMLinearKernel` also does not support ATIER block scaled GEMM and mix-precision GEMM. """ - (w_q, w_s, i_s, i_zp, azp_adj), _ = self.layer_mapping_function(layer) + w_q, w_s, i_s, i_zp, azp_adj = self._get_layer_params(layer) # ops.scaled_int8_quant supports both dynamic and static quant: # * dynamic, i_s is None and x_s computed from x. diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/cpu.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/cpu.py index 9c8ece8559b4..b84ef7814f0a 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/cpu.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/cpu.py @@ -14,10 +14,10 @@ from vllm.platforms import current_platform from vllm.platforms.interface import CpuArchEnum -from .ScaledMMLinearKernel import ScaledMMLinearKernel, Int8ScaledMMLinearLayerConfig +from .ScaledMMLinearKernel import Int8ScaledMMLinearKernel, Int8ScaledMMLinearLayerConfig -class CPUScaledMMLinearKernel(ScaledMMLinearKernel): +class CPUScaledMMLinearKernel(Int8ScaledMMLinearKernel): @classmethod def get_min_capability(cls) -> int: return 75 @@ -30,7 +30,8 @@ def can_implement(cls, c: Int8ScaledMMLinearLayerConfig) -> tuple[bool, str | No return True, None def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - weight = getattr(layer, self.w_q_name) + w_q_name, _, _, _, _ = self.layer_param_names + weight = getattr(layer, w_q_name) dtype = weight.dtype N, K = weight.size() if ( @@ -48,10 +49,13 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def process_weights_for_onednn(self, layer: torch.nn.Module) -> None: # WEIGHT # Transpose to [K, N] for convenience - weight = getattr(layer, self.w_q_name) + w_q_name, w_s_name, i_s_name, i_zp_name, azp_adj_name = ( + self.layer_param_names + ) + weight = getattr(layer, w_q_name) replace_parameter( layer, - self.w_q_name, + w_q_name, torch.nn.Parameter(weight.t().data, 
requires_grad=False), ) @@ -60,28 +64,27 @@ def process_weights_for_onednn(self, layer: torch.nn.Module) -> None: # If we have a fused module (QKV, MLP) with per tensor scales (thus N # scales being passed to the kernel), convert to the per-channel case. is_fused_module = len(layer.logical_widths) > 1 - weight_scale = getattr(layer, self.w_s_name) + weight_scale = getattr(layer, w_s_name) if is_fused_module and not self.config.is_channelwise: weight_scale = convert_to_channelwise(weight_scale, layer.logical_widths) replace_parameter( layer, - self.w_s_name, + w_s_name, torch.nn.Parameter(weight_scale.data, requires_grad=False), ) # INPUT SCALE if self.config.is_static_input_scheme: - input_scale = getattr(layer, self.i_s_name) + input_scale = getattr(layer, i_s_name) if self.config.input_symmetric: replace_parameter( layer, - self.i_s_name, + i_s_name, torch.nn.Parameter(input_scale.max(), requires_grad=False), ) - setattr(layer, self.i_zp_name, None) else: - input_zero_point = getattr(layer, self.i_zp_name) + input_zero_point = getattr(layer, i_zp_name) # reconstruct the ranges int8_traits = torch.iinfo(torch.int8) @@ -91,20 +94,16 @@ def process_weights_for_onednn(self, layer: torch.nn.Module) -> None: scale = (range_max - range_min) / (int8_traits.max - int8_traits.min) replace_parameter( - layer, self.i_s_name, torch.nn.Parameter(scale, requires_grad=False) + layer, i_s_name, torch.nn.Parameter(scale, requires_grad=False) ) azp = ( (int8_traits.min - range_min / scale).round().to(dtype=torch.int32) ) replace_parameter( - layer, self.i_zp_name, torch.nn.Parameter(azp, requires_grad=False) + layer, i_zp_name, torch.nn.Parameter(azp, requires_grad=False) ) - else: - setattr(layer, self.i_s_name, None) - setattr(layer, self.i_zp_name, None) - # Different from cutlass, oneDNN kernels only need the AZP adjustment # term for dynamic quantization. And s_b should be folded into the # term. 
Such as: @@ -112,38 +111,37 @@ def process_weights_for_onednn(self, layer: torch.nn.Module) -> None: # s_a * (s_b * AB) - s_a * s_b * zp_a * B + bias = # s_a * GEMM_output - s_a * zp_a * adj + bias if not (self.config.input_symmetric and self.config.is_static_input_scheme): - weight = getattr(layer, self.w_q_name) - weight_scale = getattr(layer, self.w_s_name) + weight = getattr(layer, w_q_name) + weight_scale = getattr(layer, w_s_name) azp_adj = weight.sum(dim=0, keepdim=True, dtype=torch.float32) azp_adj = azp_adj * weight_scale.squeeze() setattr( layer, - self.azp_adj_name, + azp_adj_name, torch.nn.Parameter(azp_adj, requires_grad=False), ) - else: - setattr(layer, self.azp_adj_name, None) - weight = getattr(layer, self.w_q_name) + weight = getattr(layer, w_q_name) self.dnnl_handler = ops.create_onednn_scaled_mm( weight, - getattr(layer, self.w_s_name), + getattr(layer, w_s_name), torch.get_default_dtype(), - getattr(layer, self.i_s_name) is None, + getattr(layer, i_s_name) is None, not self.config.input_symmetric, 32, ) # weight is prepacked and maintained by the dnnl_handler, # release the original weight - setattr(layer, self.w_q_name, None) + setattr(layer, w_q_name, None) del weight def process_weights_for_sgl(self, layer: torch.nn.Module) -> None: + w_q_name, w_s_name, _, _, _ = self.layer_param_names # WEIGHT - weight = getattr(layer, self.w_q_name) + weight = getattr(layer, w_q_name) packed_weight = torch.ops._C.convert_weight_packed(weight) replace_parameter( - layer, self.w_q_name, torch.nn.Parameter(packed_weight, requires_grad=False) + layer, w_q_name, torch.nn.Parameter(packed_weight, requires_grad=False) ) if layer.bias is not None: @@ -155,19 +153,15 @@ def process_weights_for_sgl(self, layer: torch.nn.Module) -> None: # WEIGHT SCALE # CPU SGL kernels only support per-channel. # For per-tensor quant, convert to the per-channel case. 
- weight_scale = getattr(layer, self.w_s_name) + weight_scale = getattr(layer, w_s_name) if not self.config.is_channelwise: weight_scale = convert_to_channelwise(weight_scale, layer.logical_widths) replace_parameter( layer, - self.w_s_name, + w_s_name, torch.nn.Parameter(weight_scale.data, requires_grad=False), ) - setattr(layer, self.i_s_name, None) - setattr(layer, self.i_zp_name, None) - setattr(layer, self.azp_adj_name, None) - def apply_weights( self, layer: torch.nn.Module, @@ -186,7 +180,7 @@ def _apply_weights_onednn( x: torch.Tensor, bias: torch.Tensor | None = None, ) -> torch.Tensor: - w_q, w_s, i_s, i_zp, azp_adj = self._get_weight_params(layer) + w_q, w_s, i_s, i_zp, azp_adj = self._get_layer_params(layer) # ops.scaled_int8_quant supports both dynamic and static quant: # * dynamic, i_s is None and x_s computed from x. @@ -208,7 +202,7 @@ def _apply_weights_sgl( x: torch.Tensor, bias: torch.Tensor | None = None, ) -> torch.Tensor: - w_q, w_s, _, _, _ = self._get_weight_params(layer) + w_q, w_s, _, _, _ = self._get_layer_params(layer) return torch.ops._C.int8_scaled_mm_with_quant( x, w_q, diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py index b81d67068693..2a8b68980949 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py @@ -15,10 +15,10 @@ ) from vllm.platforms import current_platform -from .ScaledMMLinearKernel import ScaledMMLinearKernel, Int8ScaledMMLinearLayerConfig, FP8ScaledMMLinearLayerConfig +from .ScaledMMLinearKernel import Int8ScaledMMLinearLayerConfig, FP8ScaledMMLinearLayerConfig, FP8ScaledMMLinearKernel, Int8ScaledMMLinearKernel +from .utils import apply_weights_fp8 - -def cutlass_w8a8_scaled_mm( +def cutlass_w8a8_scaled_mm_fp8( *, A: torch.Tensor, B: torch.Tensor, @@ -34,91 +34,7 @@ def cutlass_w8a8_scaled_mm( ) return 
output.view(*output_shape) - -def process_weights_after_loading( - config: Int8ScaledMMLinearLayerConfig, - layer: torch.nn.Module, - w_q_name: str, - w_s_name: str, - i_s_name: str, - i_zp_name: str, - azp_adj_name: str, -): - # WEIGHT - # Cutlass kernels need transposed weight. - weight = getattr(layer, w_q_name) - replace_parameter( - layer, - w_q_name, - torch.nn.Parameter(weight.t().data, requires_grad=False), - ) - - # WEIGHT SCALE - # Cutlass kernels support only per-tensor and per-channel. - # If we have a fused module (QKV, MLP) with per tensor scales (thus N - # scales being passed to the kernel), convert to the per-channel case. - is_fused_module = len(layer.logical_widths) > 1 - weight_scale = getattr(layer, w_s_name) - if is_fused_module and not config.is_channelwise: - weight_scale = convert_to_channelwise(weight_scale, layer.logical_widths) - replace_parameter( - layer, - w_s_name, - torch.nn.Parameter(weight_scale.data, requires_grad=False), - ) - - # INPUT SCALE - if config.is_static_input_scheme: - input_scale = getattr(layer, i_s_name) - - if config.input_symmetric: - replace_parameter( - layer, - i_s_name, - torch.nn.Parameter(input_scale.max(), requires_grad=False), - ) - setattr(layer, i_zp_name, None) - else: - input_zero_point = getattr(layer, i_zp_name) - - # reconstruct the ranges - int8_traits = torch.iinfo(torch.int8) - azps = input_zero_point.to(dtype=torch.int32) - range_max = (input_scale * (int8_traits.max - azps)).max() - range_min = (input_scale * (int8_traits.min - azps)).min() - - scale = (range_max - range_min) / (int8_traits.max - int8_traits.min) - replace_parameter( - layer, i_s_name, torch.nn.Parameter(scale, requires_grad=False) - ) - - # AZP loaded as int8 but used as int32 - azp = (int8_traits.min - range_min / scale).to(dtype=torch.int32) - replace_parameter( - layer, i_zp_name, torch.nn.Parameter(azp, requires_grad=False) - ) - - - # azp_adj is the AZP adjustment term, used to account for weights. 
- # It does not depend on scales or azp, so it is the same for - # static and dynamic quantization. - # For more details, see csrc/quantization/w8a8/cutlass/Epilogues.md - # https://github.com/vllm-project/vllm/blob/main/csrc/quantization/w8a8/cutlass/Epilogues.md - if not config.input_symmetric: - weight = getattr(layer, w_q_name) - azp_adj = weight.sum(dim=0, keepdim=True, dtype=torch.int32) - if config.is_static_input_scheme: - # cutlass_w8a8 requires azp to be folded into azp_adj - # in the per-tensor case - azp_adj = getattr(layer, i_zp_name) * azp_adj - setattr( - layer, - azp_adj_name, - torch.nn.Parameter(azp_adj, requires_grad=False), - ) - - -class CutlassScaledMMLinearKernel(ScaledMMLinearKernel): +class CutlassScaledMMLinearKernel(Int8ScaledMMLinearKernel): @classmethod def get_min_capability(cls) -> int: return 75 @@ -131,9 +47,83 @@ def can_implement(cls, c: Int8ScaledMMLinearLayerConfig) -> tuple[bool, str | No return True, None def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - _, param_names = self.layer_mapping_function(layer) + w_q_name, w_s_name, i_s_name, i_zp_name, azp_adj_name = ( + self.layer_param_names + ) + config = self.config + # WEIGHT + # Cutlass kernels need transposed weight. + weight = getattr(layer, w_q_name) + replace_parameter( + layer, + w_q_name, + torch.nn.Parameter(weight.t().data, requires_grad=False), + ) + + # WEIGHT SCALE + # Cutlass kernels support only per-tensor and per-channel. + # If we have a fused module (QKV, MLP) with per tensor scales (thus N + # scales being passed to the kernel), convert to the per-channel case. 
+ is_fused_module = len(layer.logical_widths) > 1 + weight_scale = getattr(layer, w_s_name) + if is_fused_module and not config.is_channelwise: + weight_scale = convert_to_channelwise(weight_scale, layer.logical_widths) + replace_parameter( + layer, + w_s_name, + torch.nn.Parameter(weight_scale.data, requires_grad=False), + ) + + # INPUT SCALE + if config.is_static_input_scheme: + input_scale = getattr(layer, i_s_name) + + if config.input_symmetric: + replace_parameter( + layer, + i_s_name, + torch.nn.Parameter(input_scale.max(), requires_grad=False), + ) + setattr(layer, i_zp_name, None) + else: + input_zero_point = getattr(layer, i_zp_name) + + # reconstruct the ranges + int8_traits = torch.iinfo(torch.int8) + azps = input_zero_point.to(dtype=torch.int32) + range_max = (input_scale * (int8_traits.max - azps)).max() + range_min = (input_scale * (int8_traits.min - azps)).min() + + scale = (range_max - range_min) / (int8_traits.max - int8_traits.min) + replace_parameter( + layer, i_s_name, torch.nn.Parameter(scale, requires_grad=False) + ) + + # AZP loaded as int8 but used as int32 + azp = (int8_traits.min - range_min / scale).to(dtype=torch.int32) + replace_parameter( + layer, i_zp_name, torch.nn.Parameter(azp, requires_grad=False) + ) + + + # azp_adj is the AZP adjustment term, used to account for weights. + # It does not depend on scales or azp, so it is the same for + # static and dynamic quantization. 
+ # For more details, see csrc/quantization/w8a8/cutlass/Epilogues.md + # https://github.com/vllm-project/vllm/blob/main/csrc/quantization/w8a8/cutlass/Epilogues.md + if not config.input_symmetric: + weight = getattr(layer, w_q_name) + azp_adj = weight.sum(dim=0, keepdim=True, dtype=torch.int32) + if config.is_static_input_scheme: + # cutlass_w8a8 requires azp to be folded into azp_adj + # in the per-tensor case + azp_adj = getattr(layer, i_zp_name) * azp_adj + setattr( + layer, + azp_adj_name, + torch.nn.Parameter(azp_adj, requires_grad=False), + ) - process_weights_after_loading(self.config, layer, *param_names) def apply_weights( self, @@ -141,7 +131,7 @@ def apply_weights( x: torch.Tensor, bias: torch.Tensor | None = None, ) -> torch.Tensor: - (w_q, w_s, i_s, i_zp, azp_adj), _ = self.layer_mapping_function(layer) + w_q, w_s, i_s, i_zp, azp_adj = self._get_layer_params(layer) # ops.scaled_int8_quant supports both dynamic and static quant: # * dynamic, i_s is None and x_s computed from x. 
@@ -170,21 +160,10 @@ def apply_weights( ) -class CutlassFP8ScaledMMLinearKernel(ScaledMMLinearKernel): - def __init__( - self, c: FP8ScaledMMLinearLayerConfig, layer_mapping_function: Callable - ) -> None: - self.quant_fp8 = QuantFP8( - static=c.is_static_input_scheme, - group_shape=GroupShape.PER_TENSOR, - num_token_padding=None, - ) - super().__init__(c, layer_mapping_function) +class CutlassFP8ScaledMMLinearKernel(FP8ScaledMMLinearKernel): - @classmethod - def get_min_capability(cls) -> int: - # lovelace and up - return 89 + def get_ouput_padding(self) -> int | None: + return None @classmethod def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: @@ -197,41 +176,20 @@ def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | Non return True, None - def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - pass - def apply_weights( self, layer: torch.nn.Module, x: torch.Tensor, bias: torch.Tensor | None = None, ): - # ops.scaled_fp8_quant supports both dynamic and static quant. - # If dynamic, layer.input_scale is None and x_scale computed from x. - # If static, layer.input_scale is scalar and x_scale is input_scale. 
- (w, w_s, x_s), _ = self.layer_mapping_function(layer) - # View input as 2D matrix for fp8 methods - x_2d = x.view(-1, x.shape[-1]) - - out_dtype = self.config.out_dtype - out_dtype = x.dtype if out_dtype is None else out_dtype - # If input not quantized - # TODO(luka) remove this path if not used anymore - x_2d_q = x_2d - if x.dtype != current_platform.fp8_dtype(): - x_2d_q, x_s = self.quant_fp8( - x_2d, - x_s, - ) - - output_shape = [*x_2d_q.shape[:-1], w.shape[1]] - - return cutlass_w8a8_scaled_mm( - A=x_2d_q, - B=w, - out_dtype=out_dtype, - As=x_s, - Bs=w_s, - bias=bias, - output_shape=output_shape, - ) + w, w_s, x_s = self._get_layer_params(layer) + return apply_weights_fp8( + cutlass_w8a8_scaled_mm_fp8, + self.quant_fp8, + w, + x, + w_s, + x_s, + bias, + self.config.out_dtype + ) \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/flash_infer.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/flash_infer.py index 9fcbb2ff8ec8..5cb4fa7150d4 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/flash_infer.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/flash_infer.py @@ -10,10 +10,11 @@ from vllm.utils.flashinfer import flashinfer_scaled_fp8_mm, has_flashinfer from .ScaledMMLinearKernel import ( - ScaledMMLinearKernel, + FP8ScaledMMLinearKernel, Int8ScaledMMLinearLayerConfig, ScaledMMLinearQuantStrategy, ) +from .utils import apply_weights_fp8 def flashinfer_w8a8_scaled_mm( @@ -30,16 +31,10 @@ def flashinfer_w8a8_scaled_mm( ) -class FlashInferScaledMMLinearKernel(ScaledMMLinearKernel): - def __init__( - self, c: Int8ScaledMMLinearLayerConfig, layer_mapping_function: Callable - ) -> None: - self.quant_fp8 = QuantFP8( - static=c.is_static_input_scheme, - group_shape=GroupShape.PER_TENSOR, - num_token_padding=None, - ) - super().__init__(c, layer_mapping_function) +class FlashInferScaledMMLinearKernel(FP8ScaledMMLinearKernel): + + def get_ouput_padding(self) -> int | 
None: + return None @classmethod def get_min_capability(cls) -> int: @@ -80,41 +75,20 @@ def can_implement(cls, c: Int8ScaledMMLinearLayerConfig) -> tuple[bool, str | No ) return True, None - def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - pass - def apply_weights( self, layer: torch.nn.Module, x: torch.Tensor, bias: torch.Tensor | None = None, ): - # ops.scaled_fp8_quant supports both dynamic and static quant. - # If dynamic, layer.input_scale is None and x_scale computed from x. - # If static, layer.input_scale is scalar and x_scale is input_scale. - (w, w_s, x_s), _ = self.layer_mapping_function(layer) - # View input as 2D matrix for fp8 methods - x_2d = x.view(-1, x.shape[-1]) - - out_dtype = self.config.out_dtype - out_dtype = x.dtype if out_dtype is None else out_dtype - # If input not quantized - # TODO(luka) remove this path if not used anymore - x_2d_q = x_2d - if x.dtype != current_platform.fp8_dtype(): - x_2d_q, x_s = self.quant_fp8( - x_2d, - x_s, - ) - - output_shape = [*x_2d_q.shape[:-1], w.shape[1]] - - return flashinfer_w8a8_scaled_mm( - A=x_2d_q, - B=w, - out_dtype=out_dtype, - As=x_s, - Bs=w_s, - bias=bias, - output_shape=output_shape, - ) + w, w_s, x_s = self._get_layer_params(layer) + return apply_weights_fp8( + flashinfer_w8a8_scaled_mm, + self.quant_fp8, + w, + x, + w_s, + x_s, + bias, + self.config.out_dtype + ) \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/rocm.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/rocm.py index 17b932f2336d..8abe124c4b6f 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/rocm.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/rocm.py @@ -12,11 +12,11 @@ from vllm.utils.torch_utils import direct_register_custom_op from .ScaledMMLinearKernel import ( - ScaledMMLinearKernel, + FP8ScaledMMLinearKernel, FP8ScaledMMLinearLayerConfig, ScaledMMLinearQuantStrategy, ) - +from .utils import 
apply_weights_fp8 def rocm_per_tensor_float_w8a8_scaled_mm_impl( A: torch.Tensor, @@ -88,20 +88,9 @@ def rocm_per_tensor_float_w8a8_scaled_mm( ) -class ROCmScaledMMLinearKernel(ScaledMMLinearKernel): - def __init__( - self, c: FP8ScaledMMLinearLayerConfig, layer_mapping_function: Callable - ) -> None: - self.quant_fp8 = QuantFP8( - static=c.is_static_input_scheme, - group_shape=GroupShape.PER_TENSOR, - num_token_padding=None, - ) - super().__init__(c, layer_mapping_function) - - @classmethod - def get_min_capability(cls) -> int: - return 90 +class ROCmScaledMMLinearKernel(FP8ScaledMMLinearKernel): + def get_ouput_padding(self) -> int | None: + return None @classmethod def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: @@ -128,7 +117,7 @@ def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | Non return ( False, "VLLM_ROCM_USE_SKINNY_GEMM must be enabled " - + "to use ROCmScaledMMLinearKernel ", + + "to use ROCmScaledMMLinearKernel.", ) if not (per_tensor_activation_scales and per_tensor_weight_scales): @@ -139,41 +128,20 @@ def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | Non ) return True, None - def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - pass - def apply_weights( self, layer: torch.nn.Module, x: torch.Tensor, bias: torch.Tensor | None = None, ): - # ops.scaled_fp8_quant supports both dynamic and static quant. - # If dynamic, layer.input_scale is None and x_scale computed from x. - # If static, layer.input_scale is scalar and x_scale is input_scale. 
- (w, w_s, x_s), _ = self.layer_mapping_function(layer) - # View input as 2D matrix for fp8 methods - x_2d = x.view(-1, x.shape[-1]) - - out_dtype = self.config.out_dtype - out_dtype = x.dtype if out_dtype is None else out_dtype - # If input not quantized - # TODO(luka) remove this path if not used anymore - x_2d_q = x_2d - if x.dtype != current_platform.fp8_dtype(): - x_2d_q, x_s = self.quant_fp8( - x_2d, - x_s, - ) - - output_shape = [*x_2d_q.shape[:-1], w.shape[1]] - - return rocm_per_tensor_float_w8a8_scaled_mm( - A=x_2d_q, - B=w, - out_dtype=out_dtype, - As=x_s, - Bs=w_s, - bias=bias, - output_shape=output_shape, + w, w_s, x_s = self._get_layer_params(layer) + return apply_weights_fp8( + rocm_per_tensor_float_w8a8_scaled_mm, + self.quant_fp8, + w, + x, + w_s, + x_s, + bias, + self.config.out_dtype ) diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/torch.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/torch.py index 7d82496dca02..8e5fc66e4fed 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/torch.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/torch.py @@ -11,11 +11,12 @@ from vllm.platforms import current_platform from .ScaledMMLinearKernel import ( - ScaledMMLinearKernel, + FP8ScaledMMLinearKernel, FP8ScaledMMLinearLayerConfig, ScaledMMLinearQuantStrategy, ) +from .utils import apply_weights_fp8 # Input scaling factors are no longer optional in _scaled_mm starting # from pytorch 2.5. 
Allocating a dummy tensor to pass as input_scale TORCH_DEVICE_IDENTITY = None @@ -134,35 +135,16 @@ return output.to(out_dtype).view(*output_shape) -class TorchScaledMMLinearKernel(ScaledMMLinearKernel): - def __init__( - self, c: FP8ScaledMMLinearLayerConfig, layer_mapping_function: Callable - ) -> None: +class TorchScaledMMLinearKernel(FP8ScaledMMLinearKernel): + def get_ouput_padding(self) -> int | None: vllm_config = get_current_vllm_config().compilation_config pad_output = vllm_config.mode < CompilationMode.VLLM_COMPILE - output_padding = 17 if pad_output else None - - self.quant_fp8 = QuantFP8( - static=c.is_static_input_scheme, - group_shape=GroupShape.PER_TENSOR, - num_token_padding=output_padding, - ) - super().__init__(c, layer_mapping_function) - - @classmethod - def get_min_capability(cls) -> int: - # lovelace and up - return 89 - - def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - return - + return 17 if pad_output else None class PerTensorTorchScaledMMLinearKernel(TorchScaledMMLinearKernel): @classmethod def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: - assert c.activation_group_shape is not None per_tensor_activation_scales = c.activation_group_shape.is_per_tensor() per_tensor_weight_scales = ( c.weight_quant_strategy == ScaledMMLinearQuantStrategy.TENSOR @@ -182,36 +164,18 @@ def apply_weights( x: torch.Tensor, bias: torch.Tensor | None = None, ): - # ops.scaled_fp8_quant supports both dynamic and static quant. - # If dynamic, layer.input_scale is None and x_scale computed from x. - # If static, layer.input_scale is scalar and x_scale is input_scale. 
- (w, w_s, x_s), _ = self.layer_mapping_function(layer) - # View input as 2D matrix for fp8 methods - x_2d = x.view(-1, x.shape[-1]) - - out_dtype = self.config.out_dtype - out_dtype = x.dtype if out_dtype is None else out_dtype - - # If input not quantized - # TODO(luka) remove this path if not used anymore - x_2d_q = x_2d - if x.dtype != current_platform.fp8_dtype(): - x_2d_q, x_s = self.quant_fp8( - x_2d, - x_s, - ) - output_shape = [*x_2d_q.shape[:-1], w.shape[1]] - return torch_per_tensor_w8a8_scaled_mm( - A=x_2d_q, - B=w, - out_dtype=out_dtype, - As=x_s, - Bs=w_s, - bias=bias, - output_shape=output_shape, + w, w_s, x_s = self._get_layer_params(layer) + return apply_weights_fp8( + torch_per_tensor_w8a8_scaled_mm, + self.quant_fp8, + w, + x, + w_s, + x_s, + bias, + self.config.out_dtype ) - class RowWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel): @classmethod def get_min_capability(cls) -> int: @@ -219,14 +183,12 @@ def get_min_capability(cls) -> int: @classmethod def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: - assert c.activation_group_shape is not None - per_tensor_activation_scales = c.activation_group_shape.is_per_tensor() per_tensor_weight_scales = ( c.weight_quant_strategy == ScaledMMLinearQuantStrategy.TENSOR ) - if per_tensor_activation_scales and per_tensor_weight_scales: + if c.activation_group_shape.is_per_tensor() or per_tensor_weight_scales: return ( False, "RowWiseTorchScaledMMLinearKernel cannot be used with " @@ -254,33 +216,16 @@ def apply_weights( x: torch.Tensor, bias: torch.Tensor | None = None, ): - # ops.scaled_fp8_quant supports both dynamic and static quant. - # If dynamic, layer.input_scale is None and x_scale computed from x. - # If static, layer.input_scale is scalar and x_scale is input_scale. 
- (w, w_s, x_s), _ = self.layer_mapping_function(layer) - # View input as 2D matrix for fp8 methods - x_2d = x.view(-1, x.shape[-1]) - - out_dtype = self.config.out_dtype - out_dtype = x.dtype if out_dtype is None else out_dtype - - # If input not quantized - # TODO(luka) remove this path if not used anymore - x_2d_q = x_2d - if x.dtype != current_platform.fp8_dtype(): - x_2d_q, x_s = self.quant_fp8( - x_2d, - x_s, - ) - output_shape = [*x_2d_q.shape[:-1], w.shape[1]] - return torch_row_wise_w8a8_scaled_mm( - A=x_2d_q, - B=w, - out_dtype=out_dtype, - As=x_s, - Bs=w_s, - bias=bias, - output_shape=output_shape, + w, w_s, x_s = self._get_layer_params(layer) + return apply_weights_fp8( + torch_row_wise_w8a8_scaled_mm, + self.quant_fp8, + w, + x, + w_s, + x_s, + bias, + self.config.out_dtype ) @@ -291,8 +236,6 @@ def get_min_capability(cls) -> int: @classmethod def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: - assert c.activation_group_shape is not None - per_tensor_activation_scales = c.activation_group_shape.is_per_tensor() per_tensor_weight_scales = ( c.weight_quant_strategy == ScaledMMLinearQuantStrategy.TENSOR @@ -313,31 +256,14 @@ def apply_weights( x: torch.Tensor, bias: torch.Tensor | None = None, ): - # ops.scaled_fp8_quant supports both dynamic and static quant. - # If dynamic, layer.input_scale is None and x_scale computed from x. - # If static, layer.input_scale is scalar and x_scale is input_scale. 
- (w, w_s, x_s), _ = self.layer_mapping_function(layer) - # View input as 2D matrix for fp8 methods - x_2d = x.view(-1, x.shape[-1]) - - out_dtype = self.config.out_dtype - out_dtype = x.dtype if out_dtype is None else out_dtype - - # If input not quantized - # TODO(luka) remove this path if not used anymore - x_2d_q = x_2d - if x.dtype != current_platform.fp8_dtype(): - x_2d_q, x_s = self.quant_fp8( - x_2d, - x_s, - ) - output_shape = [*x_2d_q.shape[:-1], w.shape[1]] - return torch_channelwise_w8a8_scaled_mm( - A=x_2d_q, - B=w, - out_dtype=out_dtype, - As=x_s, - Bs=w_s, - bias=bias, - output_shape=output_shape, + w, w_s, x_s = self._get_layer_params(layer) + return apply_weights_fp8( + torch_channelwise_w8a8_scaled_mm, + self.quant_fp8, + w, + x, + w_s, + x_s, + bias, + self.config.out_dtype ) diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/utils.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/utils.py new file mode 100644 index 000000000000..e1d5a291b846 --- /dev/null +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/utils.py @@ -0,0 +1,44 @@ +from collections.abc import Callable +import torch +from vllm.platforms import current_platform + +FP8ScaledMMCallBack = Callable[..., torch.Tensor] +FP8QuantCallback = Callable[..., tuple[torch.Tensor, torch.Tensor]] + +def apply_weights_fp8( + scaled_mm_func: FP8ScaledMMCallBack, + quant_fp8_func: FP8QuantCallback, + w:torch.Tensor, + x:torch.Tensor, + w_s:torch.Tensor, + x_s: torch.Tensor | None, + bias: torch.Tensor | None, + maybe_out_dtype: torch.dtype | None, + ) -> torch.Tensor: + # ops.scaled_fp8_quant supports both dynamic and static quant. + # If dynamic, layer.input_scale is None and x_s computed from x. + # If static, layer.input_scale is scalar and x_s is input_scale. 
+ # View input as 2D matrix for fp8 methods + x_2d = x.view(-1, x.shape[-1]) + output_shape = [*x.shape[:-1], w.shape[1]] + + out_dtype = x.dtype if maybe_out_dtype is None else maybe_out_dtype + + # If input not quantized + # TODO(luka) remove this path if not used anymore + x_2d_q = x_2d + if x.dtype != current_platform.fp8_dtype(): + x_2d_q, x_s = quant_fp8_func( + x_2d, + x_s, + ) + + return scaled_mm_func( + A=x_2d_q, + B=w, + out_dtype=out_dtype, + As=x_s, + Bs=w_s, + bias=bias, + output_shape=output_shape, + ) diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py index 6150270c8773..bafaf06ed796 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py @@ -12,10 +12,10 @@ ) from vllm.platforms import current_platform -from .ScaledMMLinearKernel import ScaledMMLinearKernel, Int8ScaledMMLinearLayerConfig +from .ScaledMMLinearKernel import Int8ScaledMMLinearKernel, Int8ScaledMMLinearLayerConfig -class XLAScaledMMLinearKernel(ScaledMMLinearKernel): +class XLAScaledMMLinearKernel(Int8ScaledMMLinearKernel): @classmethod def get_min_capability(cls) -> int: raise NotImplementedError( @@ -42,9 +42,12 @@ def can_implement(cls, c: Int8ScaledMMLinearLayerConfig) -> tuple[bool, str | No def process_weights_after_loading(self, layer: torch.nn.Module) -> None: # WEIGHT # [out, in] (different than cutlass_scaled_mm) - weight = getattr(layer, self.w_q_name) + w_q_name, w_s_name, i_s_name, i_zp_name, azp_adj_name = ( + self.layer_param_names + ) + weight = getattr(layer, w_q_name) replace_parameter( - layer, self.w_q_name, torch.nn.Parameter(weight.data, requires_grad=False) + layer, w_q_name, torch.nn.Parameter(weight.data, requires_grad=False) ) # WEIGHT SCALE @@ -52,7 +55,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: # If we have a fused module (QKV, MLP) with 
per tensor scales (thus N # scales being passed to the kernel), convert to the per-channel case. is_fused_module = len(layer.logical_widths) > 1 - weight_scale = getattr(layer, self.w_s_name) + weight_scale = getattr(layer, w_s_name) if is_fused_module and not self.config.is_channelwise: weight_scale = convert_to_channelwise(weight_scale, layer.logical_widths) @@ -60,14 +63,14 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: weight_scale = weight_scale.squeeze(-1) replace_parameter( layer, - self.w_s_name, + w_s_name, torch.nn.Parameter(weight_scale.data, requires_grad=False), ) # Only support symmetric dynamic activation quantization. - setattr(layer, self.i_s_name, None) - setattr(layer, self.i_zp_name, None) - setattr(layer, self.azp_adj_name, None) + setattr(layer, i_s_name, None) + setattr(layer, i_zp_name, None) + setattr(layer, azp_adj_name, None) # Filter warning for cond usage in apply_weights. It is okay # to specialize the graph since bias is not dynamic. @@ -88,7 +91,7 @@ def apply_weights( x: torch.Tensor, bias: torch.Tensor | None = None, ) -> torch.Tensor: - w_q, w_s, _, _, _ = self._get_weight_params(layer) + w_q, w_s, _, _, _ = self._get_layer_params(layer) # Required to register custom ops. 
import torch_xla.experimental.custom_kernel # noqa: F401 diff --git a/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_int8.py b/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_int8.py index 3d51ea2cd958..856d7fb32c09 100644 --- a/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_int8.py +++ b/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_int8.py @@ -102,24 +102,27 @@ def create_weights( layer.register_parameter("weight_zero_point", weight_zero_point) # INPUT SCALE + input_zero_point=None + input_scale=None if self.is_static_input_scheme: input_scale = BasevLLMParameter( data=torch.empty(1, dtype=torch.float32), weight_loader=weight_loader ) - layer.register_parameter("input_scale", input_scale) input_zero_point = BasevLLMParameter( data=torch.empty(1, dtype=torch.int8), weight_loader=weight_loader ) - layer.register_parameter("input_zero_point", input_zero_point) + + layer.register_parameter("input_scale", input_scale) + layer.register_parameter("input_zero_point", input_zero_point) + if not hasattr(layer, "azp_adj"): + layer.register_parameter("azp_adj", None) + + layer_param_names = ["weight", "weight_scale", "input_scale", "input_zero_point", "azp_adj"] self.kernel = kernel_type( c=scaled_mm_linear_kernel_config, - w_q_param_name="weight", - w_s_param_name="weight_scale", - i_s_param_name="input_scale", - i_zp_param_name="input_zero_point", - azp_adj_param_name="azp_adj", + layer_param_names = layer_param_names ) # Checkpoints are serialized in quark format, which is From c089ea5753cf5dff4d26fc21f9c729dd3485ef6c Mon Sep 17 00:00:00 2001 From: vllmellm Date: Thu, 30 Oct 2025 14:24:19 +0000 Subject: [PATCH 04/36] update quark fp8 path; format Signed-off-by: vllmellm --- .../schemes/compressed_tensors_w8a8_fp8.py | 27 +++++---- .../schemes/compressed_tensors_w8a8_int8.py | 24 +++++--- .../kernels/scaled_mm/ScaledMMLinearKernel.py | 58 ++++++++++--------- .../kernels/scaled_mm/__init__.py | 1 - 
.../quantization/kernels/scaled_mm/cpu.py | 9 +-- .../quantization/kernels/scaled_mm/cutlass.py | 24 ++++---- .../kernels/scaled_mm/flash_infer.py | 12 ++-- .../quantization/kernels/scaled_mm/rocm.py | 8 +-- .../quantization/kernels/scaled_mm/torch.py | 15 +++-- .../quantization/kernels/scaled_mm/utils.py | 25 ++++---- .../quantization/kernels/scaled_mm/xla.py | 9 +-- .../quark/schemes/quark_w8a8_fp8.py | 54 ++++++++++++----- .../quark/schemes/quark_w8a8_int8.py | 26 ++++++--- 13 files changed, 171 insertions(+), 121 deletions(-) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py index 53e7ed2fb3fc..a872ee15c7ae 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py @@ -6,8 +6,8 @@ import torch from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy from torch.nn import Parameter -from vllm.logger import init_logger +from vllm.logger import init_logger from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( CompressedTensorsScheme, ) @@ -15,10 +15,9 @@ _POSSIBLE_FP8_KERNELS, choose_scaled_mm_linear_kernel, ) -from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( - FP8ScaledMMLinearLayerConfig, - ScaledMMLinearQuantStrategy, +from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501 QUANT_STRATEGY_MAP, + FP8ScaledMMLinearLayerConfig, ) from vllm.model_executor.layers.quantization.utils.fp8_utils import ( W8A8BlockFp8LinearOp, @@ -53,6 +52,7 @@ logger = init_logger(__name__) + class CompressedTensorsW8A8Fp8(CompressedTensorsScheme): _kernel_backends_being_used: set[str] = set() @@ -92,17 +92,20 @@ def __init__(self, 
weight_quant: QuantizationArgs, is_static_input_scheme: bool) activation_group_shape=self.act_q_group_shape, out_dtype=self.out_dtype, ) - kernel = choose_scaled_mm_linear_kernel( + kernel_type = choose_scaled_mm_linear_kernel( scaled_mm_linear_kernel_config, _POSSIBLE_FP8_KERNELS, ) - self.fp8_linear = kernel( - scaled_mm_linear_kernel_config, layer_param_names = layer_param_names - ) - if kernel.__name__ not in self._kernel_backends_being_used: - logger.info("Using %s for CompressedTensorsW8A8FP8", kernel.__name__) - self._kernel_backends_being_used.add(kernel.__name__) + if kernel_type.__name__ not in self._kernel_backends_being_used: + logger.info( + "Using %s for CompressedTensorsW8A8FP8", kernel_type.__name__ + ) + self._kernel_backends_being_used.add(kernel_type.__name__) + + self.kernel = kernel_type( + scaled_mm_linear_kernel_config, layer_param_names=layer_param_names + ) @classmethod def get_min_capability(cls) -> int: @@ -217,4 +220,4 @@ def apply_weights( bias=bias, ) - return self.fp8_linear.apply_weights(layer, x, bias) + return self.kernel.apply_weights(layer, x, bias) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py index a0ae8655ca65..e662a1af7f1f 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py @@ -11,8 +11,11 @@ CompressedTensorsScheme, ) from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( + _POSSIBLE_INT8_KERNELS, choose_scaled_mm_linear_kernel, - _POSSIBLE_INT8_KERNELS +) +from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501 + Int8ScaledMMLinearLayerConfig, ) from vllm.model_executor.parameter import ( BasevLLMParameter, @@ -20,7 +23,6 @@ 
ModelWeightParameter, PerTensorScaleParameter, ) -from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import Int8ScaledMMLinearLayerConfig logger = init_logger(__name__) @@ -58,8 +60,7 @@ def create_weights( ) kernel_type = choose_scaled_mm_linear_kernel( - scaled_mm_linear_kernel_config, - _POSSIBLE_INT8_KERNELS + scaled_mm_linear_kernel_config, _POSSIBLE_INT8_KERNELS ) if kernel_type.__name__ not in self._kernel_backends_being_used: @@ -94,8 +95,8 @@ def create_weights( layer.register_parameter("weight_scale", weight_scale) # INPUT SCALE - input_zero_point=None - input_scale=None + input_zero_point = None + input_scale = None if self.is_static_input_scheme: input_scale = BasevLLMParameter( data=torch.empty(1, dtype=torch.float32), weight_loader=weight_loader @@ -113,11 +114,16 @@ def create_weights( if not hasattr(layer, "azp_adj"): layer.register_parameter("azp_adj", None) - layer_param_names = ["weight", "weight_scale", "input_scale", "input_zero_point", "azp_adj"] + layer_param_names = [ + "weight", + "weight_scale", + "input_scale", + "input_zero_point", + "azp_adj", + ] self.kernel = kernel_type( - c=scaled_mm_linear_kernel_config, - layer_param_names = layer_param_names + c=scaled_mm_linear_kernel_config, layer_param_names=layer_param_names ) # Checkpoints are serialized in compressed-tensors format, which is diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py index 27af30ae131c..b9acd89f69d8 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py @@ -2,15 +2,16 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from abc import ABC, abstractmethod -from collections.abc import Callable +from collections.abc import Sequence from dataclasses import dataclass 
from enum import Enum -from typing import Generic, Sequence, TypeVar +from typing import Generic, TypeVar + import torch from compressed_tensors.quantization import QuantizationStrategy -from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 +from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape class ScaledMMLinearQuantStrategy(Enum): @@ -18,21 +19,24 @@ class ScaledMMLinearQuantStrategy(Enum): CHANNEL = "channel" BLOCK = "block" + QUANT_STRATEGY_MAP = { QuantizationStrategy.TENSOR: ScaledMMLinearQuantStrategy.TENSOR, QuantizationStrategy.CHANNEL: ScaledMMLinearQuantStrategy.CHANNEL, - QuantizationStrategy.CHANNEL: ScaledMMLinearQuantStrategy.BLOCK, } + @dataclass class ScaledMMLinearLayerConfig: is_static_input_scheme: bool + @dataclass class Int8ScaledMMLinearLayerConfig(ScaledMMLinearLayerConfig): is_channelwise: bool input_symmetric: bool + @dataclass class FP8ScaledMMLinearLayerConfig(ScaledMMLinearLayerConfig): weight_quant_strategy: ScaledMMLinearQuantStrategy @@ -40,22 +44,22 @@ class FP8ScaledMMLinearLayerConfig(ScaledMMLinearLayerConfig): out_dtype: torch.dtype - +FP8ParamsT = tuple[ + torch.Tensor, # weight + torch.Tensor, # weight_scale + torch.Tensor | None, # input_scale, +] Int8ParamsT = tuple[ - torch.Tensor, # weight - torch.Tensor, # weight_scale - torch.Tensor | None, # input_scale, + torch.Tensor, # weight + torch.Tensor, # weight_scale + torch.Tensor | None, # input_scale, + torch.Tensor | None, # input_zp + torch.Tensor | None, # azp_adj ] -FP8ParamsT = tuple[ - torch.Tensor, # weight - torch.Tensor, # weight_scale - torch.Tensor | None, # input_scale, - torch.Tensor | None, # input_zp - torch.Tensor | None, # azp_adj - ] -ParamsT = TypeVar('ParamsT', Int8ParamsT, FP8ParamsT) -ConfigT = TypeVar('ConfigT', bound=ScaledMMLinearLayerConfig) +ParamsT = TypeVar("ParamsT", Int8ParamsT, FP8ParamsT) +ConfigT = 
TypeVar("ConfigT", bound=ScaledMMLinearLayerConfig) + class ScaledMMLinearKernel(Generic[ConfigT, ParamsT], ABC): @classmethod @@ -68,9 +72,7 @@ def get_min_capability(cls) -> int: def can_implement(cls, c: ConfigT) -> tuple[bool, str | None]: raise NotImplementedError - def __init__( - self, c: ConfigT, layer_param_names: Sequence[str] - ) -> None: + def __init__(self, c: ConfigT, layer_param_names: Sequence[str]) -> None: assert self.can_implement(c) self.config = c self.layer_param_names = layer_param_names @@ -87,16 +89,18 @@ def apply_weights( bias: torch.Tensor | None = None, ) -> torch.Tensor: raise NotImplementedError - + # return a covariant type in the subclass @abstractmethod def _get_layer_params(self, layer) -> ParamsT: raise NotImplementedError -class FP8ScaledMMLinearKernel(ScaledMMLinearKernel[FP8ScaledMMLinearLayerConfig, FP8ParamsT], ABC): +class FP8ScaledMMLinearKernel( + ScaledMMLinearKernel[FP8ScaledMMLinearLayerConfig, FP8ParamsT], ABC +): def __init__( - self, c: ConfigT, layer_param_names: Sequence[str] + self, c: FP8ScaledMMLinearLayerConfig, layer_param_names: Sequence[str] ) -> None: self.quant_fp8 = QuantFP8( static=c.is_static_input_scheme, @@ -104,7 +108,7 @@ def __init__( num_token_padding=self.get_ouput_padding(), ) super().__init__(c, layer_param_names) - + @abstractmethod def get_ouput_padding(self) -> int | None: raise NotImplementedError @@ -113,7 +117,7 @@ def get_ouput_padding(self) -> int | None: def get_min_capability(cls) -> int: # lovelace and up return 89 - + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: pass @@ -126,7 +130,9 @@ def _get_layer_params(self, layer) -> FP8ParamsT: ) -class Int8ScaledMMLinearKernel(ScaledMMLinearKernel[Int8ScaledMMLinearLayerConfig, Int8ParamsT], ABC): +class Int8ScaledMMLinearKernel( + ScaledMMLinearKernel[Int8ScaledMMLinearLayerConfig, Int8ParamsT], ABC +): def _get_layer_params(self, layer) -> Int8ParamsT: w_q, w_s, i_s, i_zp, azp_adj = self.layer_param_names 
return ( diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py index 85aaf51ae844..2ad21162995f 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py @@ -19,7 +19,6 @@ ScaledMMLinearKernel, ScaledMMLinearLayerConfig, ) - from vllm.model_executor.layers.quantization.kernels.scaled_mm.torch import ( ChannelWiseTorchScaledMMLinearKernel, PerTensorTorchScaledMMLinearKernel, diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/cpu.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/cpu.py index b84ef7814f0a..7fa47dd854af 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/cpu.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/cpu.py @@ -14,7 +14,10 @@ from vllm.platforms import current_platform from vllm.platforms.interface import CpuArchEnum -from .ScaledMMLinearKernel import Int8ScaledMMLinearKernel, Int8ScaledMMLinearLayerConfig +from .ScaledMMLinearKernel import ( + Int8ScaledMMLinearKernel, + Int8ScaledMMLinearLayerConfig, +) class CPUScaledMMLinearKernel(Int8ScaledMMLinearKernel): @@ -49,9 +52,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def process_weights_for_onednn(self, layer: torch.nn.Module) -> None: # WEIGHT # Transpose to [K, N] for convenience - w_q_name, w_s_name, i_s_name, i_zp_name, azp_adj_name = ( - self.layer_param_names - ) + w_q_name, w_s_name, i_s_name, i_zp_name, azp_adj_name = self.layer_param_names weight = getattr(layer, w_q_name) replace_parameter( layer, diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py index 2a8b68980949..28348f50fc27 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py +++ 
b/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py @@ -2,22 +2,24 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from collections.abc import Callable - import torch from vllm import _custom_ops as ops -from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 from vllm.model_executor.layers.quantization.utils import replace_parameter -from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( convert_to_channelwise, ) from vllm.platforms import current_platform -from .ScaledMMLinearKernel import Int8ScaledMMLinearLayerConfig, FP8ScaledMMLinearLayerConfig, FP8ScaledMMLinearKernel, Int8ScaledMMLinearKernel +from .ScaledMMLinearKernel import ( + FP8ScaledMMLinearKernel, + FP8ScaledMMLinearLayerConfig, + Int8ScaledMMLinearKernel, + Int8ScaledMMLinearLayerConfig, +) from .utils import apply_weights_fp8 + def cutlass_w8a8_scaled_mm_fp8( *, A: torch.Tensor, @@ -34,6 +36,7 @@ def cutlass_w8a8_scaled_mm_fp8( ) return output.view(*output_shape) + class CutlassScaledMMLinearKernel(Int8ScaledMMLinearKernel): @classmethod def get_min_capability(cls) -> int: @@ -47,9 +50,7 @@ def can_implement(cls, c: Int8ScaledMMLinearLayerConfig) -> tuple[bool, str | No return True, None def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - w_q_name, w_s_name, i_s_name, i_zp_name, azp_adj_name = ( - self.layer_param_names - ) + w_q_name, w_s_name, i_s_name, i_zp_name, azp_adj_name = self.layer_param_names config = self.config # WEIGHT # Cutlass kernels need transposed weight. @@ -105,7 +106,6 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: layer, i_zp_name, torch.nn.Parameter(azp, requires_grad=False) ) - # azp_adj is the AZP adjustment term, used to account for weights. # It does not depend on scales or azp, so it is the same for # static and dynamic quantization. 
@@ -124,7 +124,6 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: torch.nn.Parameter(azp_adj, requires_grad=False), ) - def apply_weights( self, layer: torch.nn.Module, @@ -161,7 +160,6 @@ def apply_weights( class CutlassFP8ScaledMMLinearKernel(FP8ScaledMMLinearKernel): - def get_ouput_padding(self) -> int | None: return None @@ -191,5 +189,5 @@ def apply_weights( w_s, x_s, bias, - self.config.out_dtype - ) \ No newline at end of file + self.config.out_dtype, + ) diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/flash_infer.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/flash_infer.py index 5cb4fa7150d4..8fd2c88857ca 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/flash_infer.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/flash_infer.py @@ -1,17 +1,14 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from collections.abc import Callable import torch -from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 -from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape from vllm.platforms import current_platform from vllm.utils.flashinfer import flashinfer_scaled_fp8_mm, has_flashinfer from .ScaledMMLinearKernel import ( FP8ScaledMMLinearKernel, - Int8ScaledMMLinearLayerConfig, + FP8ScaledMMLinearLayerConfig, ScaledMMLinearQuantStrategy, ) from .utils import apply_weights_fp8 @@ -32,7 +29,6 @@ def flashinfer_w8a8_scaled_mm( class FlashInferScaledMMLinearKernel(FP8ScaledMMLinearKernel): - def get_ouput_padding(self) -> int | None: return None @@ -41,7 +37,7 @@ def get_min_capability(cls) -> int: return 100 @classmethod - def can_implement(cls, c: Int8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: + def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: per_tensor_activation_scales = c.activation_group_shape.is_per_tensor() 
per_tensor_weight_scales = ( c.weight_quant_strategy == ScaledMMLinearQuantStrategy.TENSOR @@ -90,5 +86,5 @@ def apply_weights( w_s, x_s, bias, - self.config.out_dtype - ) \ No newline at end of file + self.config.out_dtype, + ) diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/rocm.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/rocm.py index 8abe124c4b6f..6144a94b7fb9 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/rocm.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/rocm.py @@ -1,13 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from collections.abc import Callable import torch import vllm.envs as envs from vllm import _custom_ops as ops -from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 -from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape from vllm.platforms import current_platform from vllm.utils.torch_utils import direct_register_custom_op @@ -18,6 +15,7 @@ ) from .utils import apply_weights_fp8 + def rocm_per_tensor_float_w8a8_scaled_mm_impl( A: torch.Tensor, B: torch.Tensor, @@ -40,7 +38,7 @@ def rocm_per_tensor_float_w8a8_scaled_mm_impl( current_platform.get_cu_count(), bias, ) - # Fallabck + # Fallback else: output = torch._scaled_mm( A, @@ -143,5 +141,5 @@ def apply_weights( w_s, x_s, bias, - self.config.out_dtype + self.config.out_dtype, ) diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/torch.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/torch.py index 8e5fc66e4fed..c2a8474ac5b4 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/torch.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/torch.py @@ -1,13 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from collections.abc import Callable import torch from 
packaging import version from vllm.config import CompilationMode, get_current_vllm_config -from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 -from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape from vllm.platforms import current_platform from .ScaledMMLinearKernel import ( @@ -15,8 +12,8 @@ FP8ScaledMMLinearLayerConfig, ScaledMMLinearQuantStrategy, ) - from .utils import apply_weights_fp8 + # Input scaling factors are no longer optional in _scaled_mm starting # from pytorch 2.5. Allocating a dummy tensor to pass as input_scale TORCH_DEVICE_IDENTITY = None @@ -142,6 +139,7 @@ def get_ouput_padding(self) -> int | None: output_padding = 17 if pad_output else None return output_padding + class PerTensorTorchScaledMMLinearKernel(TorchScaledMMLinearKernel): @classmethod def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: @@ -173,9 +171,10 @@ def apply_weights( w_s, x_s, bias, - self.config.out_dtype + self.config.out_dtype, ) + class RowWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel): @classmethod def get_min_capability(cls) -> int: @@ -199,7 +198,7 @@ def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | Non return ( False, "RowWiseTorchScaledMMLinearKernel is only supported " - + "in ROCm platforms.", + + "on ROCm platforms.", ) if not version.parse(torch.__version__) >= version.parse("2.7"): @@ -225,7 +224,7 @@ def apply_weights( w_s, x_s, bias, - self.config.out_dtype + self.config.out_dtype, ) @@ -265,5 +264,5 @@ def apply_weights( w_s, x_s, bias, - self.config.out_dtype + self.config.out_dtype, ) diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/utils.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/utils.py index e1d5a291b846..9f4e9a7befc4 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/utils.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/utils.py @@ -1,20 +1,25 @@ +# 
SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Callable -import torch + +import torch + from vllm.platforms import current_platform FP8ScaledMMCallBack = Callable[..., torch.Tensor] FP8QuantCallback = Callable[..., tuple[torch.Tensor, torch.Tensor]] + def apply_weights_fp8( - scaled_mm_func: FP8ScaledMMCallBack, - quant_fp8_func: FP8QuantCallback, - w:torch.Tensor, - x:torch.Tensor, - w_s:torch.Tensor, - x_s:torch.Tensor, - bias:torch.Tensor, - maybe_out_dtype: torch.dtype | None, - ) -> torch.Tensor: + scaled_mm_func: FP8ScaledMMCallBack, + quant_fp8_func: FP8QuantCallback, + w: torch.Tensor, + x: torch.Tensor, + w_s: torch.Tensor, + x_s: torch.Tensor, + bias: torch.Tensor, + maybe_out_dtype: torch.dtype | None, +) -> torch.Tensor: # ops.scaled_fp8_quant supports both dynamic and static quant. # If dynamic, layer.input_scale is None and x_s computed from x. # If static, layer.input_scale is scalar and x_s is input_scale. 
diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py index bafaf06ed796..02ec0d931bfd 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py @@ -12,7 +12,10 @@ ) from vllm.platforms import current_platform -from .ScaledMMLinearKernel import Int8ScaledMMLinearKernel, Int8ScaledMMLinearLayerConfig +from .ScaledMMLinearKernel import ( + Int8ScaledMMLinearKernel, + Int8ScaledMMLinearLayerConfig, +) class XLAScaledMMLinearKernel(Int8ScaledMMLinearKernel): @@ -42,9 +45,7 @@ def can_implement(cls, c: Int8ScaledMMLinearLayerConfig) -> tuple[bool, str | No def process_weights_after_loading(self, layer: torch.nn.Module) -> None: # WEIGHT # [out, in] (different than cutlass_scaled_mm) - w_q_name, w_s_name, i_s_name, i_zp_name, azp_adj_name = ( - self.layer_param_names - ) + w_q_name, w_s_name, i_s_name, i_zp_name, azp_adj_name = self.layer_param_names weight = getattr(layer, w_q_name) replace_parameter( layer, w_q_name, torch.nn.Parameter(weight.data, requires_grad=False) diff --git a/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py b/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py index 1e5ee93b61f2..6c296fe9a580 100644 --- a/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py +++ b/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py @@ -7,10 +7,18 @@ import torch from torch.nn import Parameter +from vllm.logger import init_logger +from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( + _POSSIBLE_FP8_KERNELS, + choose_scaled_mm_linear_kernel, +) +from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501 + FP8ScaledMMLinearLayerConfig, + ScaledMMLinearQuantStrategy, +) from vllm.model_executor.layers.quantization.quark.schemes import 
QuarkScheme from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - Fp8LinearOp, normalize_e4m3fn_to_e4m3fnuz, requantize_with_max_scale, ) @@ -23,8 +31,17 @@ __all__ = ["QuarkW8A8Fp8"] +logger = init_logger(__name__) + +QUANT_STRATEGY_MAP = { + "per_tensor": ScaledMMLinearQuantStrategy.TENSOR, + "per_channel": ScaledMMLinearQuantStrategy.CHANNEL, +} + class QuarkW8A8Fp8(QuarkScheme): + _kernel_backends_being_used: set[str] = set() + def __init__( self, weight_config: dict[str, Any], input_config: dict[str, Any] | None ): @@ -41,10 +58,6 @@ def __init__( self.act_quant_group_shape = ( GroupShape.PER_TOKEN if per_token else GroupShape.PER_TENSOR ) - self.fp8_linear = Fp8LinearOp( - act_quant_static=self.is_static_input_scheme, - act_quant_group_shape=self.act_quant_group_shape, - ) self.out_dtype = torch.get_default_dtype() @classmethod @@ -163,17 +176,32 @@ def create_weights( input_scale[:] = torch.finfo(torch.float32).min layer.register_parameter("input_scale", input_scale) + layer_param_names = ["weight", "weight_scale", "input_scale"] + weight_quant_strategy = QUANT_STRATEGY_MAP[self.weight_qscheme] + scaled_mm_linear_kernel_config = FP8ScaledMMLinearLayerConfig( + is_static_input_scheme=self.is_static_input_scheme, + weight_quant_strategy=weight_quant_strategy, + activation_group_shape=self.act_quant_group_shape, + out_dtype=self.out_dtype, + ) + kernel_type = choose_scaled_mm_linear_kernel( + scaled_mm_linear_kernel_config, + _POSSIBLE_FP8_KERNELS, + ) + + if kernel_type.__name__ not in self._kernel_backends_being_used: + logger.info("Using %s for QuarkW8A8FP8", kernel_type.__name__) + self._kernel_backends_being_used.add(kernel_type.__name__) + + layer_param_names = ["weight", "weight_scale", "input_scale"] + self.kernel = kernel_type( + c=scaled_mm_linear_kernel_config, layer_param_names=layer_param_names + ) + def apply_weights( self, layer: torch.nn.Module, 
x: torch.Tensor, bias: torch.Tensor | None = None, ) -> torch.Tensor: - return self.fp8_linear.apply( - input=x, - weight=layer.weight, - weight_scale=layer.weight_scale, - out_dtype=self.out_dtype, - input_scale=layer.input_scale, - bias=bias, - ) + return self.kernel.apply_weights(layer, x, bias) diff --git a/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_int8.py b/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_int8.py index 856d7fb32c09..2fb69fe5e40e 100644 --- a/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_int8.py +++ b/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_int8.py @@ -7,9 +7,12 @@ from vllm.logger import init_logger from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( + _POSSIBLE_INT8_KERNELS, choose_scaled_mm_linear_kernel, ) -from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import Int8ScaledMMLinearLayerConfig +from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501 + Int8ScaledMMLinearLayerConfig, +) from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme from vllm.model_executor.parameter import ( BasevLLMParameter, @@ -56,7 +59,9 @@ def create_weights( input_symmetric=(self.input_symmetric is True), ) - kernel_type = choose_scaled_mm_linear_kernel(scaled_mm_linear_kernel_config) + kernel_type = choose_scaled_mm_linear_kernel( + scaled_mm_linear_kernel_config, possible_kernels=_POSSIBLE_INT8_KERNELS + ) if kernel_type.__name__ not in self._kernel_backends_being_used: logger.info("Using %s for QuarkW8A8Int8", kernel_type.__name__) @@ -102,8 +107,8 @@ def create_weights( layer.register_parameter("weight_zero_point", weight_zero_point) # INPUT SCALE - input_zero_point=None - input_scale=None + input_zero_point = None + input_scale = None if self.is_static_input_scheme: input_scale = BasevLLMParameter( data=torch.empty(1, dtype=torch.float32), 
weight_loader=weight_loader @@ -117,12 +122,17 @@ def create_weights( layer.register_parameter("input_zero_point", input_zero_point) if not hasattr(layer, "azp_adj"): layer.register_parameter("azp_adj", None) - - layer_param_names = ["weight", "weight_scale", "input_scale", "input_zero_point", "azp_adj"] + + layer_param_names = [ + "weight", + "weight_scale", + "input_scale", + "input_zero_point", + "azp_adj", + ] self.kernel = kernel_type( - c=scaled_mm_linear_kernel_config, - layer_param_names = layer_param_names + c=scaled_mm_linear_kernel_config, layer_param_names=layer_param_names ) # Checkpoints are serialized in quark format, which is From 423e2a625e5fbbf2e35029d092649a147a72f5af Mon Sep 17 00:00:00 2001 From: vllmellm Date: Fri, 31 Oct 2025 14:07:09 +0000 Subject: [PATCH 05/36] reduce logging boilerplate; update fp8 path Signed-off-by: vllmellm --- .../schemes/compressed_tensors_w8a8_fp8.py | 10 +----- .../schemes/compressed_tensors_w8a8_int8.py | 10 ++---- .../model_executor/layers/quantization/fp8.py | 36 ++++++++++++------- .../kernels/scaled_mm/__init__.py | 7 ++++ .../quantization/kernels/scaled_mm/utils.py | 4 +-- .../quark/schemes/quark_w8a8_fp8.py | 7 +--- .../quark/schemes/quark_w8a8_int8.py | 10 ++---- 7 files changed, 41 insertions(+), 43 deletions(-) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py index a872ee15c7ae..633a41261ca9 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py @@ -54,8 +54,6 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme): - _kernel_backends_being_used: set[str] = set() - def __init__(self, weight_quant: QuantizationArgs, is_static_input_scheme: bool): self.weight_quant = weight_quant 
self.strategy = weight_quant.strategy @@ -95,14 +93,8 @@ def __init__(self, weight_quant: QuantizationArgs, is_static_input_scheme: bool) kernel_type = choose_scaled_mm_linear_kernel( scaled_mm_linear_kernel_config, _POSSIBLE_FP8_KERNELS, + module_name=self.__class__.__name__, ) - - if kernel_type.__name__ not in self._kernel_backends_being_used: - logger.info( - "Using %s for CompressedTensorsW8A8FP8", kernel_type.__name__ - ) - self._kernel_backends_being_used.add(kernel_type.__name__) - self.kernel = kernel_type( scaled_mm_linear_kernel_config, layer_param_names=layer_param_names ) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py index e662a1af7f1f..914d0e1bd08a 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py @@ -28,8 +28,6 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme): - _kernel_backends_being_used: set[str] = set() - def __init__( self, strategy: str, is_static_input_scheme: bool, input_symmetric: bool ): @@ -60,13 +58,11 @@ def create_weights( ) kernel_type = choose_scaled_mm_linear_kernel( - scaled_mm_linear_kernel_config, _POSSIBLE_INT8_KERNELS + scaled_mm_linear_kernel_config, + _POSSIBLE_INT8_KERNELS, + module_name=self.__class__.__name__, ) - if kernel_type.__name__ not in self._kernel_backends_being_used: - logger.info("Using %s for CompressedTensorsW8A8Int8", kernel_type.__name__) - self._kernel_backends_being_used.add(kernel_type.__name__) - # WEIGHT weight = ModelWeightParameter( data=torch.empty( diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index f82eccb88ce0..2bec9bf553c8 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ 
b/vllm/model_executor/layers/quantization/fp8.py @@ -42,6 +42,14 @@ QuantizeMethodBase, ) from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 +from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( + _POSSIBLE_FP8_KERNELS, + choose_scaled_mm_linear_kernel, +) +from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa E501 + FP8ScaledMMLinearLayerConfig, + ScaledMMLinearQuantStrategy, +) from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( FlashinferMoeBackend, @@ -77,7 +85,6 @@ is_layer_skipped, ) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - Fp8LinearOp, all_close_1d, cutlass_block_fp8_supported, cutlass_fp8_supported, @@ -387,9 +394,21 @@ def __init__(self, quant_config: Fp8Config): use_aiter_and_is_supported=self.use_aiter_and_is_supported, ) else: - self.fp8_linear = Fp8LinearOp( - act_quant_static=self.act_q_static, - act_quant_group_shape=self.act_q_group_shape, + scaled_mm_linear_kernel_config = FP8ScaledMMLinearLayerConfig( + is_static_input_scheme=self.act_q_static, + weight_quant_strategy=ScaledMMLinearQuantStrategy.TENSOR, + activation_group_shape=self.act_q_group_shape, + out_dtype=self.out_dtype, + ) + kernel_type = choose_scaled_mm_linear_kernel( + scaled_mm_linear_kernel_config, + _POSSIBLE_FP8_KERNELS, + module_name=self.__class__.__name__, + ) + + self.fp8_linear_kernel = kernel_type( + scaled_mm_linear_kernel_config, + layer_param_names=["weight", "weight_scale", "input_scale"], ) def create_weights( @@ -674,14 +693,7 @@ def apply( bias=bias, ) - return self.fp8_linear.apply( - input=x, - weight=layer.weight, - weight_scale=layer.weight_scale, - out_dtype=self.out_dtype, - input_scale=layer.input_scale, - bias=bias, - ) + return self.fp8_linear_kernel.apply_weights(layer, x, bias) class Fp8MoEMethod(FusedMoEMethodBase): diff --git 
a/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py index 2ad21162995f..35f9034cacdb 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py @@ -3,6 +3,7 @@ import os +from vllm.logger import init_logger from vllm.model_executor.layers.quantization.kernels.scaled_mm.aiter import ( AiterScaledMMLinearKernel, ) @@ -32,6 +33,8 @@ ) from vllm.platforms import PlatformEnum, current_platform +logger = init_logger(__name__) + # in priority/performance order (when available) _POSSIBLE_INT8_KERNELS: dict[PlatformEnum, list[type[ScaledMMLinearKernel]]] = { PlatformEnum.CPU: [CPUScaledMMLinearKernel], @@ -54,6 +57,7 @@ def choose_scaled_mm_linear_kernel( config: ScaledMMLinearLayerConfig, possible_kernels: dict[PlatformEnum, list[type[ScaledMMLinearKernel]]], + module_name: str, compute_capability: int | None = None, ) -> type[ScaledMMLinearKernel]: """ @@ -105,6 +109,9 @@ def choose_scaled_mm_linear_kernel( can_implement, failure_reason = kernel.can_implement(config) if can_implement: + logger.info_once( + "Selected %s for %s", kernel.__name__, module_name, scope="global" + ) return kernel else: failure_reasons.append( diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/utils.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/utils.py index 9f4e9a7befc4..ca1a2c5b4f29 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/utils.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/utils.py @@ -4,15 +4,15 @@ import torch +from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 from vllm.platforms import current_platform FP8ScaledMMCallBack = Callable[..., torch.Tensor] -FP8QuantCallback = Callable[..., tuple[torch.Tensor, torch.Tensor]] def apply_weights_fp8( scaled_mm_func: FP8ScaledMMCallBack, - 
quant_fp8_func: FP8QuantCallback, + quant_fp8_func: QuantFP8, w: torch.Tensor, x: torch.Tensor, w_s: torch.Tensor, diff --git a/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py b/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py index 6c296fe9a580..e8145f261b9b 100644 --- a/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py +++ b/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py @@ -40,8 +40,6 @@ class QuarkW8A8Fp8(QuarkScheme): - _kernel_backends_being_used: set[str] = set() - def __init__( self, weight_config: dict[str, Any], input_config: dict[str, Any] | None ): @@ -187,12 +185,9 @@ def create_weights( kernel_type = choose_scaled_mm_linear_kernel( scaled_mm_linear_kernel_config, _POSSIBLE_FP8_KERNELS, + module_name=self.__class__.__name__, ) - if kernel_type.__name__ not in self._kernel_backends_being_used: - logger.info("Using %s for QuarkW8A8FP8", kernel_type.__name__) - self._kernel_backends_being_used.add(kernel_type.__name__) - layer_param_names = ["weight", "weight_scale", "input_scale"] self.kernel = kernel_type( c=scaled_mm_linear_kernel_config, layer_param_names=layer_param_names diff --git a/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_int8.py b/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_int8.py index 2fb69fe5e40e..ea8db2456f86 100644 --- a/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_int8.py +++ b/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_int8.py @@ -25,8 +25,6 @@ class QuarkW8A8Int8(QuarkScheme): - _kernel_backends_being_used: set[str] = set() - def __init__( self, qscheme: str, @@ -60,13 +58,11 @@ def create_weights( ) kernel_type = choose_scaled_mm_linear_kernel( - scaled_mm_linear_kernel_config, possible_kernels=_POSSIBLE_INT8_KERNELS + scaled_mm_linear_kernel_config, + possible_kernels=_POSSIBLE_INT8_KERNELS, + module_name=self.__class__.__name__, ) - if kernel_type.__name__ 
not in self._kernel_backends_being_used: - logger.info("Using %s for QuarkW8A8Int8", kernel_type.__name__) - self._kernel_backends_being_used.add(kernel_type.__name__) - # WEIGHT weight = ModelWeightParameter( data=torch.empty( From dd001064c03c5a4dd9e179ea1886e9fb2b17d796 Mon Sep 17 00:00:00 2001 From: vllmellm Date: Fri, 31 Oct 2025 14:21:49 +0000 Subject: [PATCH 06/36] reduce kernel init boilerplate Signed-off-by: vllmellm --- .../schemes/compressed_tensors_w8a8_fp8.py | 17 ++-------- .../model_executor/layers/quantization/fp8.py | 15 ++------- .../kernels/scaled_mm/__init__.py | 31 +++++++++++++++++++ .../quark/schemes/quark_w8a8_fp8.py | 19 ++---------- 4 files changed, 39 insertions(+), 43 deletions(-) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py index 633a41261ca9..56fee0523a87 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py @@ -12,12 +12,10 @@ CompressedTensorsScheme, ) from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( - _POSSIBLE_FP8_KERNELS, - choose_scaled_mm_linear_kernel, + init_fp8_linear_kernel, ) from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501 QUANT_STRATEGY_MAP, - FP8ScaledMMLinearLayerConfig, ) from vllm.model_executor.layers.quantization.utils.fp8_utils import ( W8A8BlockFp8LinearOp, @@ -82,22 +80,13 @@ def __init__(self, weight_quant: QuantizationArgs, is_static_input_scheme: bool) use_aiter_and_is_supported=self.use_aiter_and_is_supported, ) else: - layer_param_names = ["weight", "weight_scale", "input_scale"] weight_quant_strategy = QUANT_STRATEGY_MAP[self.strategy] - scaled_mm_linear_kernel_config = FP8ScaledMMLinearLayerConfig( + 
self.fp8_linear_kernel = init_fp8_linear_kernel( is_static_input_scheme=self.is_static_input_scheme, weight_quant_strategy=weight_quant_strategy, activation_group_shape=self.act_q_group_shape, out_dtype=self.out_dtype, ) - kernel_type = choose_scaled_mm_linear_kernel( - scaled_mm_linear_kernel_config, - _POSSIBLE_FP8_KERNELS, - module_name=self.__class__.__name__, - ) - self.kernel = kernel_type( - scaled_mm_linear_kernel_config, layer_param_names=layer_param_names - ) @classmethod def get_min_capability(cls) -> int: @@ -212,4 +201,4 @@ def apply_weights( bias=bias, ) - return self.kernel.apply_weights(layer, x, bias) + return self.fp8_linear_kernel.apply_weights(layer, x, bias) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 2bec9bf553c8..0744f82ed27f 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -43,8 +43,7 @@ ) from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( - _POSSIBLE_FP8_KERNELS, - choose_scaled_mm_linear_kernel, + init_fp8_linear_kernel, ) from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa E501 FP8ScaledMMLinearLayerConfig, @@ -394,22 +393,12 @@ def __init__(self, quant_config: Fp8Config): use_aiter_and_is_supported=self.use_aiter_and_is_supported, ) else: - scaled_mm_linear_kernel_config = FP8ScaledMMLinearLayerConfig( + self.fp8_linear_kernel = init_fp8_linear_kernel( is_static_input_scheme=self.act_q_static, weight_quant_strategy=ScaledMMLinearQuantStrategy.TENSOR, activation_group_shape=self.act_q_group_shape, out_dtype=self.out_dtype, ) - kernel_type = choose_scaled_mm_linear_kernel( - scaled_mm_linear_kernel_config, - _POSSIBLE_FP8_KERNELS, - module_name=self.__class__.__name__, - ) - - self.fp8_linear_kernel = kernel_type( - scaled_mm_linear_kernel_config, - 
layer_param_names=["weight", "weight_scale", "input_scale"], - ) def create_weights( self, diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py index 35f9034cacdb..c4cadecb3af5 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py @@ -3,6 +3,8 @@ import os +import torch + from vllm.logger import init_logger from vllm.model_executor.layers.quantization.kernels.scaled_mm.aiter import ( AiterScaledMMLinearKernel, @@ -17,8 +19,11 @@ ROCmScaledMMLinearKernel, ) from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501 + FP8ScaledMMLinearKernel, + FP8ScaledMMLinearLayerConfig, ScaledMMLinearKernel, ScaledMMLinearLayerConfig, + ScaledMMLinearQuantStrategy, ) from vllm.model_executor.layers.quantization.kernels.scaled_mm.torch import ( ChannelWiseTorchScaledMMLinearKernel, @@ -32,6 +37,7 @@ XLAScaledMMLinearKernel, ) from vllm.platforms import PlatformEnum, current_platform +from vllm.vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape logger = init_logger(__name__) @@ -122,3 +128,28 @@ def choose_scaled_mm_linear_kernel( "Failed to find a kernel that can implement the " "ScaledMM linear layer. 
Reasons: \n" + "\n".join(failure_reasons) ) + + +def init_fp8_linear_kernel( + act_q_static: bool, + act_q_group_shape: GroupShape, + out_dtype: torch.dtype, + module_name: str, +) -> FP8ScaledMMLinearKernel: + scaled_mm_linear_kernel_config = FP8ScaledMMLinearLayerConfig( + is_static_input_scheme=act_q_static, + weight_quant_strategy=ScaledMMLinearQuantStrategy.TENSOR, + activation_group_shape=act_q_group_shape, + out_dtype=out_dtype, + ) + + kernel_type = choose_scaled_mm_linear_kernel( + scaled_mm_linear_kernel_config, + _POSSIBLE_FP8_KERNELS, + module_name=module_name, + ) + + return kernel_type( + scaled_mm_linear_kernel_config, + layer_param_names=["weight", "weight_scale", "input_scale"], + ) diff --git a/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py b/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py index e8145f261b9b..f053b1c438e6 100644 --- a/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py +++ b/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py @@ -9,11 +9,9 @@ from vllm.logger import init_logger from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( - _POSSIBLE_FP8_KERNELS, - choose_scaled_mm_linear_kernel, + init_fp8_linear_kernel, ) from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501 - FP8ScaledMMLinearLayerConfig, ScaledMMLinearQuantStrategy, ) from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme @@ -174,24 +172,13 @@ def create_weights( input_scale[:] = torch.finfo(torch.float32).min layer.register_parameter("input_scale", input_scale) - layer_param_names = ["weight", "weight_scale", "input_scale"] weight_quant_strategy = QUANT_STRATEGY_MAP[self.weight_qscheme] - scaled_mm_linear_kernel_config = FP8ScaledMMLinearLayerConfig( + self.fp8_linear_kernel = init_fp8_linear_kernel( is_static_input_scheme=self.is_static_input_scheme, 
weight_quant_strategy=weight_quant_strategy, activation_group_shape=self.act_quant_group_shape, out_dtype=self.out_dtype, ) - kernel_type = choose_scaled_mm_linear_kernel( - scaled_mm_linear_kernel_config, - _POSSIBLE_FP8_KERNELS, - module_name=self.__class__.__name__, - ) - - layer_param_names = ["weight", "weight_scale", "input_scale"] - self.kernel = kernel_type( - c=scaled_mm_linear_kernel_config, layer_param_names=layer_param_names - ) def apply_weights( self, @@ -199,4 +186,4 @@ def apply_weights( x: torch.Tensor, bias: torch.Tensor | None = None, ) -> torch.Tensor: - return self.kernel.apply_weights(layer, x, bias) + return self.fp8_linear_kernel.apply_weights(layer, x, bias) From 7d361487f7372199a5a8fdf307fd2afd0161f296 Mon Sep 17 00:00:00 2001 From: vllmellm Date: Fri, 31 Oct 2025 14:52:51 +0000 Subject: [PATCH 07/36] update ptpc path; bug fixes Signed-off-by: vllmellm --- .../schemes/compressed_tensors_w8a8_fp8.py | 5 ++-- .../model_executor/layers/quantization/fp8.py | 5 ++-- .../kernels/scaled_mm/__init__.py | 7 ++--- .../scaled_mm/{torch.py => pytorch.py} | 0 .../layers/quantization/ptpc_fp8.py | 26 +++++++++++-------- .../quark/schemes/quark_w8a8_fp8.py | 5 ++-- 6 files changed, 28 insertions(+), 20 deletions(-) rename vllm/model_executor/layers/quantization/kernels/scaled_mm/{torch.py => pytorch.py} (100%) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py index 56fee0523a87..f4ec97804fd4 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py @@ -82,10 +82,11 @@ def __init__(self, weight_quant: QuantizationArgs, is_static_input_scheme: bool) else: weight_quant_strategy = QUANT_STRATEGY_MAP[self.strategy] self.fp8_linear_kernel = 
init_fp8_linear_kernel( - is_static_input_scheme=self.is_static_input_scheme, + act_q_static=self.is_static_input_scheme, + act_q_group_shape=self.act_q_group_shape, weight_quant_strategy=weight_quant_strategy, - activation_group_shape=self.act_q_group_shape, out_dtype=self.out_dtype, + module_name=self.__class__.__name__ ) @classmethod diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 0744f82ed27f..91988a18fda7 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -394,10 +394,11 @@ def __init__(self, quant_config: Fp8Config): ) else: self.fp8_linear_kernel = init_fp8_linear_kernel( - is_static_input_scheme=self.act_q_static, + act_q_static=self.act_q_static, + act_q_group_shape=self.act_q_group_shape, weight_quant_strategy=ScaledMMLinearQuantStrategy.TENSOR, - activation_group_shape=self.act_q_group_shape, out_dtype=self.out_dtype, + module_name=self.__class__.__name__ ) def create_weights( diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py index c4cadecb3af5..629ed790b966 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py @@ -25,7 +25,7 @@ ScaledMMLinearLayerConfig, ScaledMMLinearQuantStrategy, ) -from vllm.model_executor.layers.quantization.kernels.scaled_mm.torch import ( +from vllm.model_executor.layers.quantization.kernels.scaled_mm.pytorch import ( ChannelWiseTorchScaledMMLinearKernel, PerTensorTorchScaledMMLinearKernel, RowWiseTorchScaledMMLinearKernel, @@ -37,7 +37,7 @@ XLAScaledMMLinearKernel, ) from vllm.platforms import PlatformEnum, current_platform -from vllm.vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape +from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape 
logger = init_logger(__name__) @@ -133,12 +133,13 @@ def choose_scaled_mm_linear_kernel( def init_fp8_linear_kernel( act_q_static: bool, act_q_group_shape: GroupShape, + weight_quant_strategy: ScaledMMLinearQuantStrategy, out_dtype: torch.dtype, module_name: str, ) -> FP8ScaledMMLinearKernel: scaled_mm_linear_kernel_config = FP8ScaledMMLinearLayerConfig( is_static_input_scheme=act_q_static, - weight_quant_strategy=ScaledMMLinearQuantStrategy.TENSOR, + weight_quant_strategy=weight_quant_strategy, activation_group_shape=act_q_group_shape, out_dtype=out_dtype, ) diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/torch.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/pytorch.py similarity index 100% rename from vllm/model_executor/layers/quantization/kernels/scaled_mm/torch.py rename to vllm/model_executor/layers/quantization/kernels/scaled_mm/pytorch.py diff --git a/vllm/model_executor/layers/quantization/ptpc_fp8.py b/vllm/model_executor/layers/quantization/ptpc_fp8.py index 26ba8e5b16bc..5352ba9c4500 100644 --- a/vllm/model_executor/layers/quantization/ptpc_fp8.py +++ b/vllm/model_executor/layers/quantization/ptpc_fp8.py @@ -16,11 +16,16 @@ Fp8KVCacheMethod, Fp8LinearMethod, ) +from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( + init_fp8_linear_kernel, +) +from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa E501 + ScaledMMLinearQuantStrategy, +) from vllm.model_executor.layers.quantization.utils.quant_utils import ( GroupShape, is_layer_skipped, ) -from vllm.model_executor.layers.quantization.utils.w8a8_utils import Fp8LinearOp from vllm.platforms import current_platform ACTIVATION_SCHEMES = ["static", "dynamic"] @@ -98,11 +103,15 @@ def __init__(self, quant_config: PTPCFp8Config): ) super().__init__(quant_config=quant_config) # Force weight quantization - self.quant_config.is_checkpoint_fp8_serialized = False - self.fp8_linear = Fp8LinearOp( - 
act_quant_static=False, act_quant_group_shape=GroupShape.PER_TOKEN + self.fp8_linear_kernel = init_fp8_linear_kernel( + act_q_static=False, + act_q_group_shape=GroupShape.PER_TOKEN, + weight_quant_strategy=ScaledMMLinearQuantStrategy.CHANNEL, + out_dtype=self.out_dtype, + module_name=self.__class__.__name__ ) + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False) @@ -127,11 +136,6 @@ def apply( x: torch.Tensor, bias: torch.Tensor | None = None, ) -> torch.Tensor: - return self.fp8_linear.apply( - input=x, - weight=layer.weight, - weight_scale=layer.weight_scale, - input_scale=None, - input_scale_ub=None, - bias=bias, + return self.fp8_linear_kernel.apply_weights( + layer, x, bias ) diff --git a/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py b/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py index f053b1c438e6..94a90747e281 100644 --- a/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py +++ b/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py @@ -174,10 +174,11 @@ def create_weights( weight_quant_strategy = QUANT_STRATEGY_MAP[self.weight_qscheme] self.fp8_linear_kernel = init_fp8_linear_kernel( - is_static_input_scheme=self.is_static_input_scheme, + act_q_static=self.is_static_input_scheme, + act_q_group_shape=self.act_quant_group_shape, weight_quant_strategy=weight_quant_strategy, - activation_group_shape=self.act_quant_group_shape, out_dtype=self.out_dtype, + module_name=self.__class__.__name__ ) def apply_weights( From 1f65cd56e5d9cd8dcddab6191fb5efb22a93430a Mon Sep 17 00:00:00 2001 From: vllmellm Date: Fri, 31 Oct 2025 15:06:51 +0000 Subject: [PATCH 08/36] revert input scale upper bounds Signed-off-by: vllmellm --- .../kernels/scaled_mm/ScaledMMLinearKernel.py | 4 +++- .../layers/quantization/kernels/scaled_mm/cutlass.py | 3 ++- 
.../layers/quantization/kernels/scaled_mm/flash_infer.py | 3 ++- .../layers/quantization/kernels/scaled_mm/pytorch.py | 9 ++++++--- .../layers/quantization/kernels/scaled_mm/rocm.py | 3 ++- .../layers/quantization/kernels/scaled_mm/utils.py | 2 ++ 6 files changed, 17 insertions(+), 7 deletions(-) diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py index b9acd89f69d8..9798f88b140a 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py @@ -48,6 +48,7 @@ class FP8ScaledMMLinearLayerConfig(ScaledMMLinearLayerConfig): torch.Tensor, # weight torch.Tensor, # weight_scale torch.Tensor | None, # input_scale, + torch.Tensor | None, # input_scale_ub, ] Int8ParamsT = tuple[ torch.Tensor, # weight @@ -122,11 +123,12 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: pass def _get_layer_params(self, layer) -> FP8ParamsT: - w, w_s, x_s = self.layer_param_names + w, w_s, x_s, x_s_ub = self.layer_param_names return ( getattr(layer, w), getattr(layer, w_s), getattr(layer, x_s), + getattr(layer, x_s_ub), ) diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py index 28348f50fc27..fc8893cb7e1b 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py @@ -180,7 +180,7 @@ def apply_weights( x: torch.Tensor, bias: torch.Tensor | None = None, ): - w, w_s, x_s = self._get_layer_params(layer) + w, w_s, x_s, x_s_ub = self._get_layer_params(layer) return apply_weights_fp8( cutlass_w8a8_scaled_mm_fp8, self.quant_fp8, @@ -189,5 +189,6 @@ def apply_weights( w_s, x_s, bias, + x_s_ub, self.config.out_dtype, ) diff --git 
a/vllm/model_executor/layers/quantization/kernels/scaled_mm/flash_infer.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/flash_infer.py index 8fd2c88857ca..e33b30532204 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/flash_infer.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/flash_infer.py @@ -77,7 +77,7 @@ def apply_weights( x: torch.Tensor, bias: torch.Tensor | None = None, ): - w, w_s, x_s = self._get_layer_params(layer) + w, w_s, x_s, x_s_ub = self._get_layer_params(layer) return apply_weights_fp8( flashinfer_w8a8_scaled_mm, self.quant_fp8, @@ -86,5 +86,6 @@ def apply_weights( w_s, x_s, bias, + x_s_ub, self.config.out_dtype, ) diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/pytorch.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/pytorch.py index c2a8474ac5b4..c0466e840fc0 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/pytorch.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/pytorch.py @@ -162,7 +162,7 @@ def apply_weights( x: torch.Tensor, bias: torch.Tensor | None = None, ): - w, w_s, x_s = self._get_layer_params(layer) + w, w_s, x_s, x_s_ub = self._get_layer_params(layer) return apply_weights_fp8( torch_per_tensor_w8a8_scaled_mm, self.quant_fp8, @@ -171,6 +171,7 @@ def apply_weights( w_s, x_s, bias, + x_s_ub, self.config.out_dtype, ) @@ -215,7 +216,7 @@ def apply_weights( x: torch.Tensor, bias: torch.Tensor | None = None, ): - w, w_s, x_s = self._get_layer_params(layer) + w, w_s, x_s, x_s_ub = self._get_layer_params(layer) return apply_weights_fp8( torch_row_wise_w8a8_scaled_mm, self.quant_fp8, @@ -224,6 +225,7 @@ def apply_weights( w_s, x_s, bias, + x_s_ub, self.config.out_dtype, ) @@ -255,7 +257,7 @@ def apply_weights( x: torch.Tensor, bias: torch.Tensor | None = None, ): - w, w_s, x_s = self._get_layer_params(layer) + w, w_s, x_s, x_s_ub = self._get_layer_params(layer) return apply_weights_fp8( 
torch_channelwise_w8a8_scaled_mm, self.quant_fp8, @@ -264,5 +266,6 @@ def apply_weights( w_s, x_s, bias, + x_s_ub, self.config.out_dtype, ) diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/rocm.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/rocm.py index 6144a94b7fb9..63744337a7e5 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/rocm.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/rocm.py @@ -132,7 +132,7 @@ def apply_weights( x: torch.Tensor, bias: torch.Tensor | None = None, ): - w, w_s, x_s = self._get_layer_params(layer) + w, w_s, x_s, x_s_ub = self._get_layer_params(layer) return apply_weights_fp8( rocm_per_tensor_float_w8a8_scaled_mm, self.quant_fp8, @@ -141,5 +141,6 @@ def apply_weights( w_s, x_s, bias, + x_s_ub, self.config.out_dtype, ) diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/utils.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/utils.py index ca1a2c5b4f29..8323690817d6 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/utils.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/utils.py @@ -18,6 +18,7 @@ def apply_weights_fp8( w_s: torch.Tensor, x_s: torch.Tensor, bias: torch.Tensor, + x_s_ub: torch.Tensor | None, maybe_out_dtype: torch.dtype | None, ) -> torch.Tensor: # ops.scaled_fp8_quant supports both dynamic and static quant. 
@@ -36,6 +37,7 @@ def apply_weights_fp8( x_2d_q, x_s = quant_fp8_func( x_2d, x_s, + x_s_ub, ) return scaled_mm_func( From 5fbe76bc0ad9d2c6a6d1bbe9c20faefdabc35cfc Mon Sep 17 00:00:00 2001 From: vllmellm Date: Fri, 31 Oct 2025 15:08:19 +0000 Subject: [PATCH 09/36] format; update fbgemm path Signed-off-by: vllmellm --- .../schemes/compressed_tensors_w8a8_fp8.py | 2 +- .../layers/quantization/fbgemm_fp8.py | 24 ++++++++++++------- .../model_executor/layers/quantization/fp8.py | 2 +- .../kernels/scaled_mm/__init__.py | 14 +++++------ .../layers/quantization/ptpc_fp8.py | 7 ++---- .../quark/schemes/quark_w8a8_fp8.py | 2 +- 6 files changed, 27 insertions(+), 24 deletions(-) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py index f4ec97804fd4..1d0e36a3fc55 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py @@ -86,7 +86,7 @@ def __init__(self, weight_quant: QuantizationArgs, is_static_input_scheme: bool) act_q_group_shape=self.act_q_group_shape, weight_quant_strategy=weight_quant_strategy, out_dtype=self.out_dtype, - module_name=self.__class__.__name__ + module_name=self.__class__.__name__, ) @classmethod diff --git a/vllm/model_executor/layers/quantization/fbgemm_fp8.py b/vllm/model_executor/layers/quantization/fbgemm_fp8.py index 6ba18e59e4d5..fb16681f03a0 100644 --- a/vllm/model_executor/layers/quantization/fbgemm_fp8.py +++ b/vllm/model_executor/layers/quantization/fbgemm_fp8.py @@ -18,6 +18,12 @@ QuantizationConfig, QuantizeMethodBase, ) +from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( + init_fp8_linear_kernel, +) +from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( + 
ScaledMMLinearQuantStrategy, +) from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin, @@ -96,6 +102,14 @@ def __init__(self, quant_config: FBGEMMFp8Config): ) self.out_dtype = torch.get_default_dtype() + self.fp8_linear_kernel = init_fp8_linear_kernel( + act_q_static=False, + act_q_group_shape=GroupShape.PER_TOKEN, + weight_quant_strategy=ScaledMMLinearQuantStrategy.CHANNEL, + out_dtype=self.out_dtype, + module_name=self.__class__.__name__, + ) + def create_weights( self, layer: torch.nn.Module, @@ -184,12 +198,4 @@ def apply( bias=bias, ) - return self.fp8_linear.apply( - input=x, - weight=layer.weight, - weight_scale=layer.weight_scale, - out_dtype=self.out_dtype, - input_scale=None, - input_scale_ub=layer.input_scale_ub, - bias=bias, - ) + return self.fp8_linear_kernel.apply_weights(layer, x, bias) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 91988a18fda7..484a8d7ab3af 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -398,7 +398,7 @@ def __init__(self, quant_config: Fp8Config): act_q_group_shape=self.act_q_group_shape, weight_quant_strategy=ScaledMMLinearQuantStrategy.TENSOR, out_dtype=self.out_dtype, - module_name=self.__class__.__name__ + module_name=self.__class__.__name__, ) def create_weights( diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py index 629ed790b966..26baba602945 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py @@ -15,6 +15,11 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm.cutlass import ( CutlassScaledMMLinearKernel, ) +from vllm.model_executor.layers.quantization.kernels.scaled_mm.pytorch import ( + 
ChannelWiseTorchScaledMMLinearKernel, + PerTensorTorchScaledMMLinearKernel, + RowWiseTorchScaledMMLinearKernel, +) from vllm.model_executor.layers.quantization.kernels.scaled_mm.rocm import ( ROCmScaledMMLinearKernel, ) @@ -25,19 +30,14 @@ ScaledMMLinearLayerConfig, ScaledMMLinearQuantStrategy, ) -from vllm.model_executor.layers.quantization.kernels.scaled_mm.pytorch import ( - ChannelWiseTorchScaledMMLinearKernel, - PerTensorTorchScaledMMLinearKernel, - RowWiseTorchScaledMMLinearKernel, -) from vllm.model_executor.layers.quantization.kernels.scaled_mm.triton import ( TritonScaledMMLinearKernel, ) from vllm.model_executor.layers.quantization.kernels.scaled_mm.xla import ( XLAScaledMMLinearKernel, ) -from vllm.platforms import PlatformEnum, current_platform from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape +from vllm.platforms import PlatformEnum, current_platform logger = init_logger(__name__) @@ -133,7 +133,7 @@ def choose_scaled_mm_linear_kernel( def init_fp8_linear_kernel( act_q_static: bool, act_q_group_shape: GroupShape, - weight_quant_strategy: ScaledMMLinearQuantStrategy, + weight_quant_strategy: ScaledMMLinearQuantStrategy, out_dtype: torch.dtype, module_name: str, ) -> FP8ScaledMMLinearKernel: diff --git a/vllm/model_executor/layers/quantization/ptpc_fp8.py b/vllm/model_executor/layers/quantization/ptpc_fp8.py index 5352ba9c4500..2634bbd4bd87 100644 --- a/vllm/model_executor/layers/quantization/ptpc_fp8.py +++ b/vllm/model_executor/layers/quantization/ptpc_fp8.py @@ -108,10 +108,9 @@ def __init__(self, quant_config: PTPCFp8Config): act_q_group_shape=GroupShape.PER_TOKEN, weight_quant_strategy=ScaledMMLinearQuantStrategy.CHANNEL, out_dtype=self.out_dtype, - module_name=self.__class__.__name__ + module_name=self.__class__.__name__, ) - def process_weights_after_loading(self, layer: torch.nn.Module) -> None: layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False) @@ -136,6 +135,4 @@ def apply( x: torch.Tensor, 
bias: torch.Tensor | None = None, ) -> torch.Tensor: - return self.fp8_linear_kernel.apply_weights( - layer, x, bias - ) + return self.fp8_linear_kernel.apply_weights(layer, x, bias) diff --git a/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py b/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py index 94a90747e281..f32c14e27f68 100644 --- a/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py +++ b/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py @@ -178,7 +178,7 @@ def create_weights( act_q_group_shape=self.act_quant_group_shape, weight_quant_strategy=weight_quant_strategy, out_dtype=self.out_dtype, - module_name=self.__class__.__name__ + module_name=self.__class__.__name__, ) def apply_weights( From e845035f4c6c491914203c018c0ea51a564f780a Mon Sep 17 00:00:00 2001 From: vllmellm Date: Fri, 31 Oct 2025 16:38:26 +0000 Subject: [PATCH 10/36] bug fix Signed-off-by: vllmellm --- .../compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py | 2 ++ vllm/model_executor/layers/quantization/fbgemm_fp8.py | 2 +- vllm/model_executor/layers/quantization/fp8.py | 1 + .../layers/quantization/kernels/scaled_mm/__init__.py | 2 +- .../layers/quantization/quark/schemes/quark_w8a8_fp8.py | 2 ++ 5 files changed, 7 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py index 1d0e36a3fc55..58ea30edcd63 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py @@ -146,6 +146,8 @@ def create_weights( input_scale = create_fp8_input_scale(output_partition_sizes, weight_loader) layer.register_parameter("input_scale", input_scale) + 
layer.register_parameter("input_scale_ub", None) + def process_weights_after_loading(self, layer) -> None: if self.strategy == QuantizationStrategy.TENSOR: weight, weight_scale, input_scale = process_fp8_weight_tensor_strategy( diff --git a/vllm/model_executor/layers/quantization/fbgemm_fp8.py b/vllm/model_executor/layers/quantization/fbgemm_fp8.py index fb16681f03a0..a7b8e6ddda71 100644 --- a/vllm/model_executor/layers/quantization/fbgemm_fp8.py +++ b/vllm/model_executor/layers/quantization/fbgemm_fp8.py @@ -21,7 +21,7 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( init_fp8_linear_kernel, ) -from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( +from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501 ScaledMMLinearQuantStrategy, ) from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 484a8d7ab3af..48697e3849e0 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -451,6 +451,7 @@ def create_weights( weight_loader=weight_loader, ) layer.register_parameter("weight", weight) + layer.register_parameter("input_scale_ub", None) # If checkpoint is serialized fp8, load them. # Otherwise, wait until process_weights_after_loading. 
diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py index 26baba602945..3c0ee8323c55 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py @@ -152,5 +152,5 @@ def init_fp8_linear_kernel( return kernel_type( scaled_mm_linear_kernel_config, - layer_param_names=["weight", "weight_scale", "input_scale"], + layer_param_names=["weight", "weight_scale", "input_scale", "input_scale_ub"], ) diff --git a/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py b/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py index f32c14e27f68..6fff44900007 100644 --- a/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py +++ b/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py @@ -172,6 +172,8 @@ def create_weights( input_scale[:] = torch.finfo(torch.float32).min layer.register_parameter("input_scale", input_scale) + layer.register_parameter("input_scale_ub", None) + weight_quant_strategy = QUANT_STRATEGY_MAP[self.weight_qscheme] self.fp8_linear_kernel = init_fp8_linear_kernel( act_q_static=self.is_static_input_scheme, From d92c23b446521c367b89119075f1288f05ae7177 Mon Sep 17 00:00:00 2001 From: vllmellm Date: Sat, 1 Nov 2025 09:59:00 +0000 Subject: [PATCH 11/36] fix types; reduce boilerplate for int8 Signed-off-by: vllmellm --- .../schemes/compressed_tensors_w8a8_int8.py | 25 +------ .../kernels/scaled_mm/ScaledMMLinearKernel.py | 24 +++--- .../kernels/scaled_mm/__init__.py | 75 +++++++++++++++---- .../quark/schemes/quark_w8a8_int8.py | 25 +------ 4 files changed, 77 insertions(+), 72 deletions(-) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py 
index 914d0e1bd08a..652feb196457 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py @@ -11,11 +11,7 @@ CompressedTensorsScheme, ) from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( - _POSSIBLE_INT8_KERNELS, - choose_scaled_mm_linear_kernel, -) -from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501 - Int8ScaledMMLinearLayerConfig, + init_int8_linear_kernel, ) from vllm.model_executor.parameter import ( BasevLLMParameter, @@ -51,15 +47,10 @@ def create_weights( ): layer.logical_widths = output_partition_sizes - scaled_mm_linear_kernel_config = Int8ScaledMMLinearLayerConfig( + self.kernel = init_int8_linear_kernel( is_channelwise=(self.strategy == QuantizationStrategy.CHANNEL), is_static_input_scheme=self.is_static_input_scheme, input_symmetric=self.input_symmetric, - ) - - kernel_type = choose_scaled_mm_linear_kernel( - scaled_mm_linear_kernel_config, - _POSSIBLE_INT8_KERNELS, module_name=self.__class__.__name__, ) @@ -110,18 +101,6 @@ def create_weights( if not hasattr(layer, "azp_adj"): layer.register_parameter("azp_adj", None) - layer_param_names = [ - "weight", - "weight_scale", - "input_scale", - "input_zero_point", - "azp_adj", - ] - - self.kernel = kernel_type( - c=scaled_mm_linear_kernel_config, layer_param_names=layer_param_names - ) - # Checkpoints are serialized in compressed-tensors format, which is # different from the format the kernel may want. Handle repacking here. 
def process_weights_after_loading(self, layer: torch.nn.Module) -> None: diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py index 9798f88b140a..329078f0a489 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py @@ -44,13 +44,13 @@ class FP8ScaledMMLinearLayerConfig(ScaledMMLinearLayerConfig): out_dtype: torch.dtype -FP8ParamsT = tuple[ +_FP8ParamsT = tuple[ torch.Tensor, # weight torch.Tensor, # weight_scale torch.Tensor | None, # input_scale, torch.Tensor | None, # input_scale_ub, ] -Int8ParamsT = tuple[ +_Int8ParamsT = tuple[ torch.Tensor, # weight torch.Tensor, # weight_scale torch.Tensor | None, # input_scale, @@ -58,11 +58,11 @@ class FP8ScaledMMLinearLayerConfig(ScaledMMLinearLayerConfig): torch.Tensor | None, # azp_adj ] -ParamsT = TypeVar("ParamsT", Int8ParamsT, FP8ParamsT) -ConfigT = TypeVar("ConfigT", bound=ScaledMMLinearLayerConfig) +_ParamsT = TypeVar("_ParamsT", _Int8ParamsT, _FP8ParamsT) +_ConfigT = TypeVar("_ConfigT", bound=ScaledMMLinearLayerConfig) -class ScaledMMLinearKernel(Generic[ConfigT, ParamsT], ABC): +class ScaledMMLinearKernel(Generic[_ConfigT, _ParamsT], ABC): @classmethod @abstractmethod def get_min_capability(cls) -> int: @@ -70,10 +70,10 @@ def get_min_capability(cls) -> int: @classmethod @abstractmethod - def can_implement(cls, c: ConfigT) -> tuple[bool, str | None]: + def can_implement(cls, c: _ConfigT) -> tuple[bool, str | None]: raise NotImplementedError - def __init__(self, c: ConfigT, layer_param_names: Sequence[str]) -> None: + def __init__(self, c: _ConfigT, layer_param_names: Sequence[str]) -> None: assert self.can_implement(c) self.config = c self.layer_param_names = layer_param_names @@ -93,12 +93,12 @@ def apply_weights( # return a covariant type in the subclass 
@abstractmethod - def _get_layer_params(self, layer) -> ParamsT: + def _get_layer_params(self, layer) -> _ParamsT: raise NotImplementedError class FP8ScaledMMLinearKernel( - ScaledMMLinearKernel[FP8ScaledMMLinearLayerConfig, FP8ParamsT], ABC + ScaledMMLinearKernel[FP8ScaledMMLinearLayerConfig, _FP8ParamsT], ABC ): def __init__( self, c: FP8ScaledMMLinearLayerConfig, layer_param_names: Sequence[str] @@ -122,7 +122,7 @@ def get_min_capability(cls) -> int: def process_weights_after_loading(self, layer: torch.nn.Module) -> None: pass - def _get_layer_params(self, layer) -> FP8ParamsT: + def _get_layer_params(self, layer) -> _FP8ParamsT: w, w_s, x_s, x_s_ub = self.layer_param_names return ( getattr(layer, w), @@ -133,9 +133,9 @@ def _get_layer_params(self, layer) -> FP8ParamsT: class Int8ScaledMMLinearKernel( - ScaledMMLinearKernel[Int8ScaledMMLinearLayerConfig, Int8ParamsT], ABC + ScaledMMLinearKernel[Int8ScaledMMLinearLayerConfig, _Int8ParamsT], ABC ): - def _get_layer_params(self, layer) -> Int8ParamsT: + def _get_layer_params(self, layer) -> _Int8ParamsT: w_q, w_s, i_s, i_zp, azp_adj = self.layer_param_names return ( getattr(layer, w_q), diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py index 3c0ee8323c55..2e00775b90d6 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import os +from typing import TypeVar import torch @@ -13,6 +14,7 @@ CPUScaledMMLinearKernel, ) from vllm.model_executor.layers.quantization.kernels.scaled_mm.cutlass import ( + CutlassFP8ScaledMMLinearKernel, CutlassScaledMMLinearKernel, ) from vllm.model_executor.layers.quantization.kernels.scaled_mm.pytorch import ( @@ -26,6 +28,8 @@ from 
vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501 FP8ScaledMMLinearKernel, FP8ScaledMMLinearLayerConfig, + Int8ScaledMMLinearKernel, + Int8ScaledMMLinearLayerConfig, ScaledMMLinearKernel, ScaledMMLinearLayerConfig, ScaledMMLinearQuantStrategy, @@ -42,15 +46,16 @@ logger = init_logger(__name__) # in priority/performance order (when available) -_POSSIBLE_INT8_KERNELS: dict[PlatformEnum, list[type[ScaledMMLinearKernel]]] = { +_POSSIBLE_INT8_KERNELS: dict[PlatformEnum, list[type[Int8ScaledMMLinearKernel]]] = { PlatformEnum.CPU: [CPUScaledMMLinearKernel], PlatformEnum.CUDA: [CutlassScaledMMLinearKernel], PlatformEnum.ROCM: [AiterScaledMMLinearKernel, TritonScaledMMLinearKernel], PlatformEnum.TPU: [XLAScaledMMLinearKernel], } -_POSSIBLE_FP8_KERNELS: dict[PlatformEnum, list[type[ScaledMMLinearKernel]]] = { - PlatformEnum.CUDA: [CutlassScaledMMLinearKernel], +# in priority/performance order (when available) +_POSSIBLE_FP8_KERNELS: dict[PlatformEnum, list[type[FP8ScaledMMLinearKernel]]] = { + PlatformEnum.CUDA: [CutlassFP8ScaledMMLinearKernel], PlatformEnum.ROCM: [ ROCmScaledMMLinearKernel, PerTensorTorchScaledMMLinearKernel, @@ -59,21 +64,25 @@ ], } +_KernelT = TypeVar("_KernelT", bound=ScaledMMLinearKernel, covariant=True) +_KernelConfigT = TypeVar("_KernelConfigT", bound=ScaledMMLinearLayerConfig) + def choose_scaled_mm_linear_kernel( - config: ScaledMMLinearLayerConfig, - possible_kernels: dict[PlatformEnum, list[type[ScaledMMLinearKernel]]], - module_name: str, + config: _KernelConfigT, + possible_kernels: dict[PlatformEnum, list[type[_KernelT]]], compute_capability: int | None = None, -) -> type[ScaledMMLinearKernel]: +) -> type[_KernelT]: """ - Choose an ScaledMMLinearKernel that can implement the given config for the + Choose a _KernelT that can implement the given config for the given compute capability. Attempts to choose the best kernel in terms of performance. 
Args: - config (ScaledMMLinearLayerConfig): Description of the linear layer + config (_KernelConfigT): Description of the linear layer to be implemented. + possible_kernels (dict[PlatformEnum, list[_KernelT]]): A + dictionary of platforms and their list list of possible kernels. compute_capability (Optional[int], optional): The compute capability of the target device, if None uses `current_platform` to get the compute capability. Defaults to None. @@ -82,7 +91,7 @@ def choose_scaled_mm_linear_kernel( ValueError: If no kernel can implement the given config. Returns: - type[ScaledMMLinearKernel]: Chosen kernel. + _KernelT: Chosen kernel. """ if compute_capability is None: @@ -115,9 +124,6 @@ def choose_scaled_mm_linear_kernel( can_implement, failure_reason = kernel.can_implement(config) if can_implement: - logger.info_once( - "Selected %s for %s", kernel.__name__, module_name, scope="global" - ) return kernel else: failure_reasons.append( @@ -147,10 +153,51 @@ def init_fp8_linear_kernel( kernel_type = choose_scaled_mm_linear_kernel( scaled_mm_linear_kernel_config, _POSSIBLE_FP8_KERNELS, - module_name=module_name, + ) + + logger.info_once( + "Selected %s for %s", + kernel_type.__class__.__name__, + module_name, + scope="global", ) return kernel_type( scaled_mm_linear_kernel_config, layer_param_names=["weight", "weight_scale", "input_scale", "input_scale_ub"], ) + + +def init_int8_linear_kernel( + is_channelwise: bool, + is_static_input_scheme: bool, + input_symmetric: bool, + module_name: str, +) -> Int8ScaledMMLinearKernel: + config = Int8ScaledMMLinearLayerConfig( + is_channelwise=is_channelwise, + is_static_input_scheme=is_static_input_scheme, + input_symmetric=input_symmetric, + ) + + kernel_type = choose_scaled_mm_linear_kernel( + config, _POSSIBLE_INT8_KERNELS, + ) + + logger.info_once( + "Selected %s for %s", + kernel_type.__class__.__name__, + module_name, + scope="global", + ) + + return kernel_type( + config, + layer_param_names=[ + "weight", + 
"weight_scale", + "input_scale", + "input_zero_point", + "azp_adj", + ], + ) diff --git a/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_int8.py b/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_int8.py index ea8db2456f86..a7a7726bae0e 100644 --- a/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_int8.py +++ b/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_int8.py @@ -7,11 +7,7 @@ from vllm.logger import init_logger from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( - _POSSIBLE_INT8_KERNELS, - choose_scaled_mm_linear_kernel, -) -from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501 - Int8ScaledMMLinearLayerConfig, + init_int8_linear_kernel, ) from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme from vllm.model_executor.parameter import ( @@ -51,15 +47,10 @@ def create_weights( ): layer.logical_widths = output_partition_sizes - scaled_mm_linear_kernel_config = Int8ScaledMMLinearLayerConfig( + self.kernel = init_int8_linear_kernel( is_channelwise=(self.qscheme == "per_channel"), is_static_input_scheme=(self.is_static_input_scheme is True), input_symmetric=(self.input_symmetric is True), - ) - - kernel_type = choose_scaled_mm_linear_kernel( - scaled_mm_linear_kernel_config, - possible_kernels=_POSSIBLE_INT8_KERNELS, module_name=self.__class__.__name__, ) @@ -119,18 +110,6 @@ def create_weights( if not hasattr(layer, "azp_adj"): layer.register_parameter("azp_adj", None) - layer_param_names = [ - "weight", - "weight_scale", - "input_scale", - "input_zero_point", - "azp_adj", - ] - - self.kernel = kernel_type( - c=scaled_mm_linear_kernel_config, layer_param_names=layer_param_names - ) - # Checkpoints are serialized in quark format, which is # different from the format the kernel may want. Handle repacking here. 
def process_weights_after_loading(self, layer: torch.nn.Module) -> None: From 4ce0ba2df421f26665ff4a603d5775c3cef85e4c Mon Sep 17 00:00:00 2001 From: vllmellm Date: Sat, 1 Nov 2025 10:01:13 +0000 Subject: [PATCH 12/36] format Signed-off-by: vllmellm --- .../layers/quantization/kernels/scaled_mm/__init__.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py index 2e00775b90d6..08c1ced5f08d 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py @@ -181,7 +181,8 @@ def init_int8_linear_kernel( ) kernel_type = choose_scaled_mm_linear_kernel( - config, _POSSIBLE_INT8_KERNELS, + config, + _POSSIBLE_INT8_KERNELS, ) logger.info_once( From dd5a70ec71a8d91c8bea0edb2c2d2b16eaddd256 Mon Sep 17 00:00:00 2001 From: vllmellm Date: Sat, 1 Nov 2025 16:28:03 +0000 Subject: [PATCH 13/36] update unit tests to use ScaledMMLinearKernels Signed-off-by: vllmellm --- tests/compile/test_functionalization.py | 53 +++++++++++-------- tests/compile/test_fusion.py | 36 ++++++++----- tests/compile/test_fusion_all_reduce.py | 39 +++++++++----- tests/compile/test_fusion_attn.py | 32 +++++++---- tests/compile/test_sequence_parallelism.py | 28 ++++++---- tests/compile/test_silu_mul_quant_fusion.py | 27 ++++++---- tests/utils.py | 11 ++++ vllm/config/structured_outputs.py | 3 +- .../layers/quantization/fbgemm_fp8.py | 4 -- .../kernels/scaled_mm/ScaledMMLinearKernel.py | 2 +- .../kernels/scaled_mm/__init__.py | 2 +- .../quantization/kernels/scaled_mm/utils.py | 4 +- 12 files changed, 152 insertions(+), 89 deletions(-) diff --git a/tests/compile/test_functionalization.py b/tests/compile/test_functionalization.py index 11ae96e930da..4d979f075d78 100644 --- a/tests/compile/test_functionalization.py +++ b/tests/compile/test_functionalization.py @@ 
-20,8 +20,13 @@ ) from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( + init_fp8_linear_kernel, +) +from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501 + ScaledMMLinearQuantStrategy, +) from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape -from vllm.model_executor.layers.quantization.utils.w8a8_utils import Fp8LinearOp from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.platforms import current_platform @@ -35,21 +40,23 @@ class TestSiluMul(torch.nn.Module): def __init__(self, hidden_size: int = 128): super().__init__() self.silu_and_mul = SiluAndMul() - self.wscale = torch.rand(1, dtype=torch.float32) - self.scale = torch.rand(1, dtype=torch.float32) - + self.weight_scale = torch.rand(1, dtype=torch.float32) + self.input_scale = torch.rand(1, dtype=torch.float32) + self.input_scale_ub = None if TEST_FP8: - self.w = torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t() - self.fp8_linear = Fp8LinearOp( - act_quant_static=True, - act_quant_group_shape=GroupShape.PER_TENSOR, + self.weight = torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t() + self.fp8_linear = init_fp8_linear_kernel( + act_q_static=True, + act_q_group_shape=GroupShape.PER_TENSOR, + weight_quant_strategy=ScaledMMLinearQuantStrategy.TENSOR, + out_dtype=torch.get_default_dtype(), + module_name=self.__class__.__name__, ) def forward(self, x): y = self.silu_and_mul(x) if TEST_FP8: - x2 = self.fp8_linear.apply(y, self.w, self.wscale, input_scale=self.wscale) - return x2 + return self.fp8_linear.apply_weights(self, y) else: return y @@ -81,11 +88,19 @@ def __init__(self, hidden_size=16, intermediate_size=32): torch.nn.init.normal_(self.gate_proj, std=0.02) if TEST_FP8: - self.fp8_linear = Fp8LinearOp(act_quant_static=True) - - self.scale = 
torch.rand(1, dtype=torch.float32) - self.w = torch.rand(hidden_size, intermediate_size).to(dtype=FP8_DTYPE).t() - self.wscale = torch.rand(1, dtype=torch.float32) + self.fp8_linear = init_fp8_linear_kernel( + act_q_static=True, + act_q_group_shape=GroupShape.PER_TENSOR, + weight_quant_strategy=ScaledMMLinearQuantStrategy.TENSOR, + out_dtype=torch.get_default_dtype(), + module_name=self.__class__.__name__, + ) + self.weight = ( + torch.rand(hidden_size, intermediate_size).to(dtype=FP8_DTYPE).t() + ) + self.weight_scale = torch.rand(1, dtype=torch.float32) + self.input_scale = torch.rand(1, dtype=torch.float32) + self.input_scale_ub = None def forward(self, hidden_states, residual): # Reshape input @@ -99,13 +114,9 @@ def forward(self, hidden_states, residual): norm_output, residual_output = self.norm(mm, residual) if TEST_FP8: + self.input_scale = self.input_scale.to(norm_output.device) # scaled_mm with static input quantization - fp8_linear_result = self.fp8_linear.apply( - norm_output, - self.w, - self.wscale, - input_scale=self.scale.to(norm_output.device), - ) + fp8_linear_result = self.fp8_linear.apply_weights(self, norm_output) return fp8_linear_result, residual_output diff --git a/tests/compile/test_fusion.py b/tests/compile/test_fusion.py index 286f2276367a..ed925a4d55cc 100644 --- a/tests/compile/test_fusion.py +++ b/tests/compile/test_fusion.py @@ -18,19 +18,24 @@ VllmConfig, ) from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( + init_fp8_linear_kernel, +) +from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501 + ScaledMMLinearQuantStrategy, +) from vllm.model_executor.layers.quantization.utils.quant_utils import ( GroupShape, QuantKey, ScaleDesc, ) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - Fp8LinearOp, cutlass_fp8_supported, maybe_create_device_identity, ) from vllm.platforms import 
current_platform -from ..utils import override_cutlass_fp8_supported +from ..utils import TestFP8Layer, override_cutlass_fp8_supported from .backend import TestBackend FP8_DTYPE = current_platform.fp8_dtype() @@ -54,6 +59,8 @@ def __init__( self.norm = [RMSNorm(hidden_size, eps) for _ in range(4)] self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(3)] group_shape = GroupShape.PER_TENSOR if static else GroupShape.PER_TOKEN + weight_quant_strategy = ScaledMMLinearQuantStrategy.TENSOR + quant_scale = ScaleDesc(torch.float32, static, group_shape) self.quant_key = QuantKey(dtype=FP8_DTYPE, scale=quant_scale, symmetric=True) if static: @@ -66,9 +73,12 @@ def __init__( ] with override_cutlass_fp8_supported(not cuda_force_torch): - self.fp8_linear = Fp8LinearOp( - act_quant_static=static, - act_quant_group_shape=group_shape, + self.fp8_linear = init_fp8_linear_kernel( + act_q_static=static, + act_q_group_shape=group_shape, + weight_quant_strategy=weight_quant_strategy, + out_dtype=torch.get_default_dtype(), + module_name=self.__class__.__name__, ) self.enable_rms_norm_custom_op = self.norm[0].enabled() @@ -79,20 +89,20 @@ def forward(self, x): x = resid = torch.relu(x) y = self.norm[0](x) - x2 = self.fp8_linear.apply( - y, self.w[0], self.wscale[0], input_scale=self.scale[0] - ) + layer1 = TestFP8Layer(self.w[0], self.wscale[0], input_scale=self.scale[0]) + x2 = self.fp8_linear.apply_weights(layer1, y) # make sure resid is used for replacement to work y2, resid = self.norm[1](x2, resid) - x3 = self.fp8_linear.apply( - y2, self.w[1], self.wscale[1], input_scale=self.scale[1] - ) + layer2 = TestFP8Layer(self.w[1], self.wscale[1], input_scale=self.scale[1]) + x3 = self.fp8_linear.apply_weights(layer2, y2) y3, resid = self.norm[2](x3, resid) # use resid here - x4 = self.fp8_linear.apply( - y3, self.w[2], self.wscale[2], input_scale=self.scale[2] + layer3 = TestFP8Layer(self.w[2], self.wscale[2], input_scale=self.scale[2]) + x4 = self.fp8_linear.apply_weights( + 
layer3, + y3, ) y4, resid = self.norm[3](x4, resid) # use resid here diff --git a/tests/compile/test_fusion_all_reduce.py b/tests/compile/test_fusion_all_reduce.py index 6d0a0ed7d89d..2dc6f8d2f925 100644 --- a/tests/compile/test_fusion_all_reduce.py +++ b/tests/compile/test_fusion_all_reduce.py @@ -26,14 +26,19 @@ initialize_model_parallel, ) from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( + init_fp8_linear_kernel, +) +from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501 + ScaledMMLinearQuantStrategy, +) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - Fp8LinearOp, GroupShape, ) from vllm.platforms import current_platform from vllm.utils.system_utils import update_environment_variables -from ..utils import has_module_attribute, multi_gpu_test +from ..utils import TestFP8Layer, has_module_attribute, multi_gpu_test from .backend import TestBackend @@ -81,43 +86,49 @@ def __init__(self, hidden_size=16, token_num=16, eps=1e-6): self.eps = eps self.norm = [RMSNorm(hidden_size, eps) for i in range(4)] self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(3)] - self.w = [ + self.input_scale = [torch.rand(1, dtype=torch.float32) for _ in range(3)] + self.weight = [ torch.rand(hidden_size, hidden_size) .to(dtype=current_platform.fp8_dtype()) .t() for _ in range(3) ] - self.fp8_linear = Fp8LinearOp( - act_quant_static=True, - act_quant_group_shape=GroupShape.PER_TENSOR, + self.fp8_linear = init_fp8_linear_kernel( + act_q_static=True, + act_q_group_shape=GroupShape.PER_TENSOR, + weight_quant_strategy=ScaledMMLinearQuantStrategy.TENSOR, + out_dtype=torch.get_default_dtype(), + module_name=self.__class__.__name__, ) - self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(3)] - def forward(self, hidden_states): # avoid having graph input be an arg to a pattern directly z = torch.relu(hidden_states) x = 
resid = tensor_model_parallel_all_reduce(z) y = self.norm[0](x) - z2 = self.fp8_linear.apply( - y, self.w[0], self.wscale[0], input_scale=self.scale[0] + layer1 = TestFP8Layer( + self.weight[0], self.weight_scale[0], input_scale=self.input_scale[0] ) + z2 = self.fp8_linear.apply_weights(layer1, y) x2 = tensor_model_parallel_all_reduce(z2) y2, resid = self.norm[1](x2, resid) - z3 = self.fp8_linear.apply( - y2, self.w[1], self.wscale[1], input_scale=self.scale[1] + layer2 = TestFP8Layer( + self.weight[1], self.weight_scale[1], input_scale=self.input_scale[1] ) + z3 = self.fp8_linear.apply(layer2, y2) x3 = tensor_model_parallel_all_reduce(z3) y3, resid = self.norm[2](x3, resid) # use resid here - z4 = self.fp8_linear.apply( - y3, self.w[2], self.wscale[2], input_scale=self.scale[2] + layer3 = TestFP8Layer( + self.weight[2], self.weight_scale[2], input_scale=self.input_scale[2] ) + z4 = self.fp8_linear.apply(layer3, y3) + x4 = tensor_model_parallel_all_reduce(z4) y4, resid = self.norm[3](x4, resid) # use resid here return y4 diff --git a/tests/compile/test_fusion_attn.py b/tests/compile/test_fusion_attn.py index fecb1e2e918f..a6ebf46d98dd 100644 --- a/tests/compile/test_fusion_attn.py +++ b/tests/compile/test_fusion_attn.py @@ -28,16 +28,23 @@ set_current_vllm_config, ) from vllm.forward_context import get_forward_context, set_forward_context +from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( + init_fp8_linear_kernel, +) +from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501 + ScaledMMLinearQuantStrategy, +) from vllm.model_executor.layers.quantization.utils.quant_utils import ( QuantKey, kFp8StaticTensorSym, kNvfp4Quant, ) -from vllm.model_executor.layers.quantization.utils.w8a8_utils import Fp8LinearOp from vllm.platforms import current_platform from vllm.utils.flashinfer import has_flashinfer from vllm.v1.kv_cache_interface import AttentionSpec +from ..utils import TestFP8Layer + FP8_DTYPE = 
current_platform.fp8_dtype() FP4_DTYPE = torch.uint8 @@ -170,11 +177,18 @@ class TestAttentionFp8StaticQuantPatternModel(AttentionQuantPatternModel): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.fp8_linear = Fp8LinearOp( - act_quant_static=self.quant_key.scale.static, - act_quant_group_shape=self.quant_key.scale.group_shape, + if self.quant_key.scale.group_shape.is_per_tensor(): + weight_quant_strategy = ScaledMMLinearQuantStrategy.TENSOR + else: + weight_quant_strategy = ScaledMMLinearQuantStrategy.CHANNEL + + self.fp8_linear = init_fp8_linear_kernel( + act_q_static=self.quant_key.scale.static, + act_q_group_shape=self.quant_key.scale.group_shape, + weight_quant_strategy=weight_quant_strategy, + out_dtype=torch.get_default_dtype(), + module_name=self.__class__.__name__, ) - hidden_size = self.num_qo_heads * self.head_size self.w = kwargs.get( "w", @@ -190,12 +204,8 @@ def __init__(self, *args, **kwargs): def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor): """Forward pass that creates the pattern to be fused.""" attn_output = self.attn(q, k, v) - return self.fp8_linear.apply( - input=attn_output, - weight=self.w["weight"], - weight_scale=self.w["wscale"], - input_scale=self.w["scale"], - ) + layer = TestFP8Layer(self.w["weight"], self.w["wscale"], self.w["scale"]) + return self.fp8_linear.apply_weights(layer, attn_output) class TestAttentionNvfp4QuantPatternModel(AttentionQuantPatternModel): diff --git a/tests/compile/test_sequence_parallelism.py b/tests/compile/test_sequence_parallelism.py index e909cf7393ad..007339cd86f7 100644 --- a/tests/compile/test_sequence_parallelism.py +++ b/tests/compile/test_sequence_parallelism.py @@ -27,11 +27,17 @@ initialize_model_parallel, ) from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.quantization.utils.w8a8_utils import Fp8LinearOp +from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( + init_fp8_linear_kernel, +) 
+from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501 + ScaledMMLinearQuantStrategy, +) +from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape from vllm.platforms import current_platform from vllm.utils.system_utils import update_environment_variables -from ..utils import multi_gpu_test +from ..utils import TestFP8Layer, multi_gpu_test from .backend import TestBackend FP8_DTYPE = current_platform.fp8_dtype() @@ -107,8 +113,13 @@ def __init__(self, hidden_size=16, intermediate_size=32): # Initialize weights torch.nn.init.normal_(self.gate_proj, std=0.02) - self.fp8_linear = Fp8LinearOp(act_quant_static=True) - + self.fp8_linear = init_fp8_linear_kernel( + act_q_static=True, + act_q_group_shape=GroupShape.PER_TENSOR, + weight_quant_strategy=ScaledMMLinearQuantStrategy.TENSOR, + out_dtype=torch.get_default_dtype(), + module_name=self.__class__.__name__, + ) self.scale = torch.rand(1, dtype=torch.float32) # Create a weight that is compatible with torch._scaled_mm, # which expects a column-major layout. 
@@ -138,14 +149,9 @@ def forward(self, hidden_states, residual): # layer normalization norm_output, residual_output = self.norm(all_reduce, residual) - # scaled_mm with static input quantization - fp8_linear_result = self.fp8_linear.apply( - norm_output, - self.w, - self.wscale, - input_scale=self.scale.to(norm_output.device), - ) + layer = TestFP8Layer(None, None, self.scale.to(norm_output.device)) + fp8_linear_result = self.fp8_linear.apply(layer, norm_output) return fp8_linear_result, residual_output diff --git a/tests/compile/test_silu_mul_quant_fusion.py b/tests/compile/test_silu_mul_quant_fusion.py index 0ddb82b7c3fc..2ce52b97f13e 100644 --- a/tests/compile/test_silu_mul_quant_fusion.py +++ b/tests/compile/test_silu_mul_quant_fusion.py @@ -24,13 +24,18 @@ set_current_vllm_config, ) from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( + init_fp8_linear_kernel, +) +from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501 + ScaledMMLinearQuantStrategy, +) from vllm.model_executor.layers.quantization.utils.quant_utils import ( GroupShape, kFp8StaticTensorSym, kNvfp4Quant, ) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - Fp8LinearOp, maybe_create_device_identity, ) from vllm.platforms import current_platform @@ -50,22 +55,26 @@ class TestSiluMulFp8QuantModel(torch.nn.Module): def __init__(self, hidden_size: int, cuda_force_torch: bool, **kwargs): super().__init__() self.silu_and_mul = SiluAndMul() - self.wscale = torch.rand(1, dtype=torch.float32) - self.scale = torch.rand(1, dtype=torch.float32) - - self.w = torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t() + self.weight_scale = torch.rand(1, dtype=torch.float32) + self.input_scale = torch.rand(1, dtype=torch.float32) + self.input_scale_ub = None + self.weight = torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t() with 
override_cutlass_fp8_supported(not cuda_force_torch): - self.fp8_linear = Fp8LinearOp( - act_quant_static=True, - act_quant_group_shape=GroupShape.PER_TENSOR, + self.fp8_linear = init_fp8_linear_kernel( + act_q_static=True, + act_q_group_shape=GroupShape.PER_TENSOR, + weight_quant_strategy=ScaledMMLinearQuantStrategy.TENSOR, + out_dtype=torch.get_default_dtype(), + module_name=self.__class__.__name__, ) + self.enable_silu_mul_custom_op = self.silu_and_mul.enabled() self.enable_quant_fp8_custom_op = self.fp8_linear.quant_fp8.enabled() def forward(self, x): y = self.silu_and_mul(x) - x2 = self.fp8_linear.apply(y, self.w, self.wscale, input_scale=self.wscale) + x2 = self.fp8_linear.apply_weights(self, y) return x2 def ops_in_model_before(self): diff --git a/tests/utils.py b/tests/utils.py index af4ce6ebaeda..bb3bbc750350 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1411,3 +1411,14 @@ def flat_product(*iterables: Iterable[Any]): for element in itertools.product(*iterables): normalized = (e if isinstance(e, tuple) else (e,) for e in element) yield tuple(itertools.chain(*normalized)) + + +class TestFP8Layer(torch.nn.Module): + """Helper class for ScaledMMLinearKernels.""" + + def __init__(self, weight, weight_scale, input_scale): + super().__init__() + self.weight_scale = weight_scale + self.weight = weight + self.input_scale = input_scale + self.input_scale_ub = None diff --git a/vllm/config/structured_outputs.py b/vllm/config/structured_outputs.py index 85b6e42264a4..eb1cc7220b8f 100644 --- a/vllm/config/structured_outputs.py +++ b/vllm/config/structured_outputs.py @@ -2,10 +2,11 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import hashlib -from typing import Any, Literal, Self +from typing import Any, Literal from pydantic import model_validator from pydantic.dataclasses import dataclass +from typing_extensions import Self from vllm.config.utils import config diff --git a/vllm/model_executor/layers/quantization/fbgemm_fp8.py 
b/vllm/model_executor/layers/quantization/fbgemm_fp8.py index a7b8e6ddda71..5fa419ebaa91 100644 --- a/vllm/model_executor/layers/quantization/fbgemm_fp8.py +++ b/vllm/model_executor/layers/quantization/fbgemm_fp8.py @@ -33,7 +33,6 @@ is_layer_skipped, ) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - Fp8LinearOp, maybe_create_device_identity, normalize_e4m3fn_to_e4m3fnuz, ) @@ -97,9 +96,6 @@ def get_quant_method( class FBGEMMFp8LinearMethod(LinearMethodBase): def __init__(self, quant_config: FBGEMMFp8Config): self.quant_config = quant_config - self.fp8_linear = Fp8LinearOp( - act_quant_static=False, act_quant_group_shape=GroupShape.PER_TOKEN - ) self.out_dtype = torch.get_default_dtype() self.fp8_linear_kernel = init_fp8_linear_kernel( diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py index 329078f0a489..f3bff8cae0ef 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py @@ -41,7 +41,7 @@ class Int8ScaledMMLinearLayerConfig(ScaledMMLinearLayerConfig): class FP8ScaledMMLinearLayerConfig(ScaledMMLinearLayerConfig): weight_quant_strategy: ScaledMMLinearQuantStrategy activation_group_shape: GroupShape - out_dtype: torch.dtype + out_dtype: torch.dtype | None _FP8ParamsT = tuple[ diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py index 08c1ced5f08d..901f0649a6d4 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py @@ -64,7 +64,7 @@ ], } -_KernelT = TypeVar("_KernelT", bound=ScaledMMLinearKernel, covariant=True) +_KernelT = TypeVar("_KernelT", bound=ScaledMMLinearKernel) _KernelConfigT = 
TypeVar("_KernelConfigT", bound=ScaledMMLinearLayerConfig) diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/utils.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/utils.py index 8323690817d6..62bbacbc782c 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/utils.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/utils.py @@ -7,11 +7,9 @@ from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 from vllm.platforms import current_platform -FP8ScaledMMCallBack = Callable[..., torch.Tensor] - def apply_weights_fp8( - scaled_mm_func: FP8ScaledMMCallBack, + scaled_mm_func: Callable[..., torch.Tensor], quant_fp8_func: QuantFP8, w: torch.Tensor, x: torch.Tensor, From 52ff5374592572128e376a4dd1fc9bd7a6815fdb Mon Sep 17 00:00:00 2001 From: vllmellm Date: Sat, 1 Nov 2025 16:28:18 +0000 Subject: [PATCH 14/36] update modelopt path Signed-off-by: vllmellm --- .../layers/quantization/modelopt.py | 23 +++++++++++-------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index 37b682984fc3..f478cd319e66 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -37,6 +37,12 @@ QuantizationConfig, QuantizeMethodBase, ) +from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( + init_fp8_linear_kernel, +) +from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501 + ScaledMMLinearQuantStrategy, +) from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import ( build_flashinfer_fp4_cutlass_moe_prepare_finalize, @@ -68,7 +74,6 @@ swizzle_blockscale, ) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - Fp8LinearOp, 
requantize_with_max_scale, ) from vllm.model_executor.parameter import ModelWeightParameter, PerTensorScaleParameter @@ -254,8 +259,12 @@ class ModelOptFp8LinearMethod(LinearMethodBase): def __init__(self, quant_config: ModelOptFp8Config) -> None: self.quant_config = quant_config - self.fp8_linear = Fp8LinearOp( - act_quant_static=True, act_quant_group_shape=GroupShape.PER_TENSOR + self.fp8_linear = init_fp8_linear_kernel( + act_q_static=True, + act_q_group_shape=GroupShape.PER_TENSOR, + weight_quant_strategy=ScaledMMLinearQuantStrategy.TENSOR, + out_dtype=None, + module_name=self.__class__.__name__, ) def create_weights( @@ -323,13 +332,7 @@ def apply( x: torch.Tensor, bias: torch.Tensor | None = None, ) -> torch.Tensor: - return self.fp8_linear.apply( - input=x, - weight=layer.weight, - weight_scale=layer.weight_scale, - input_scale=layer.input_scale, - bias=bias, - ) + return self.fp8_linear.apply_weights(layer, x, bias) class ModelOptFp8MoEMethod(FusedMoEMethodBase): From b13c4bb25c5af5d7348ce04a9f2b622330ce3fb6 Mon Sep 17 00:00:00 2001 From: vllmellm Date: Sat, 1 Nov 2025 16:30:32 +0000 Subject: [PATCH 15/36] remove FP8LinearOps Signed-off-by: vllmellm --- .../layers/quantization/utils/w8a8_utils.py | 355 ------------------ 1 file changed, 355 deletions(-) diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py index 380431e86435..f2d8eecdc68e 100644 --- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py @@ -1,19 +1,12 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from collections.abc import Callable import torch from packaging import version from vllm import _custom_ops as ops -from vllm import envs -from vllm.config import CompilationMode, get_current_vllm_config -from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 
-from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape from vllm.platforms import current_platform -from vllm.utils.flashinfer import flashinfer_scaled_fp8_mm, has_flashinfer -from vllm.utils.torch_utils import direct_register_custom_op # Input scaling factors are no longer optional in _scaled_mm starting # from pytorch 2.5. Allocating a dummy tensor to pass as input_scale @@ -143,354 +136,6 @@ def maybe_create_device_identity(): TORCH_DEVICE_IDENTITY = torch.ones(1, dtype=torch.float32) -def cutlass_w8a8_scaled_mm( - *, - qinput: torch.Tensor, - weight: torch.Tensor, - out_dtype: torch.dtype, - scale_a: torch.Tensor, - scale_b: torch.Tensor, - bias: torch.Tensor, - output_shape: list, - **kwargs, -) -> torch.Tensor: - # Fused GEMM_DQ - output = ops.cutlass_scaled_mm( - qinput, weight, out_dtype=out_dtype, scale_a=scale_a, scale_b=scale_b, bias=bias - ) - return output.view(*output_shape) - - -def flashinfer_w8a8_scaled_mm( - *, - qinput: torch.Tensor, - weight: torch.Tensor, - out_dtype: torch.dtype, - scale_a: torch.Tensor, - scale_b: torch.Tensor, - bias: torch.Tensor, - output_shape: list, - **kwargs, -) -> torch.Tensor: - return flashinfer_scaled_fp8_mm( - qinput, weight, out_dtype=out_dtype, scale_a=scale_a, scale_b=scale_b, bias=bias - ) - - -def rocm_per_tensor_w8a8_scaled_mm_impl( - qinput: torch.Tensor, - weight: torch.Tensor, - out_dtype: torch.dtype, - scale_a: torch.Tensor, - scale_b: torch.Tensor, - bias: torch.Tensor, -) -> torch.Tensor: - from vllm.platforms.rocm import on_mi3xx - - if ( - envs.VLLM_ROCM_USE_SKINNY_GEMM - and on_mi3xx() - and qinput.shape[0] == 1 - and qinput.shape[1] % 16 == 0 - and ((bias is None) or (bias.dtype == out_dtype)) - ): - output = ops.wvSplitKQ( - weight.t(), - qinput, - out_dtype, - scale_a, - scale_b, - current_platform.get_cu_count(), - bias, - ) - else: - output = torch._scaled_mm( - qinput, - weight, - out_dtype=out_dtype, - scale_a=scale_a, - scale_b=scale_b, - bias=bias, - ) - return 
output - - -def rocm_per_tensor_w8a8_scaled_mm_fake( - qinput: torch.Tensor, - weight: torch.Tensor, - out_dtype: torch.dtype, - scale_a: torch.Tensor, - scale_b: torch.Tensor, - bias: torch.Tensor, -) -> torch.Tensor: - return qinput.new_empty((*qinput.shape[:-1], weight.shape[1]), dtype=out_dtype) - - -def rocm_per_tensor_w8a8_scaled_mm( - *, - qinput: torch.Tensor, - weight: torch.Tensor, - out_dtype: torch.dtype, - scale_a: torch.Tensor, - scale_b: torch.Tensor, - bias: torch.Tensor, - output_shape: list, -) -> torch.Tensor: - output = torch.ops.vllm.rocm_per_tensor_w8a8_scaled_mm_impl( - qinput, weight, out_dtype, scale_a, scale_b, bias - ) - return torch.narrow(output, 0, 0, qinput.shape[0]).view(*output_shape) - - -direct_register_custom_op( - op_name="rocm_per_tensor_w8a8_scaled_mm_impl", - op_func=rocm_per_tensor_w8a8_scaled_mm_impl, - fake_impl=rocm_per_tensor_w8a8_scaled_mm_fake, -) - - -def torch_per_tensor_w8a8_scaled_mm( - *, - qinput: torch.Tensor, - weight: torch.Tensor, - out_dtype: torch.dtype, - scale_a: torch.Tensor, - scale_b: torch.Tensor, - bias: torch.Tensor, - output_shape: list, -) -> torch.Tensor: - output = torch._scaled_mm( - qinput, weight, out_dtype=out_dtype, scale_a=scale_a, scale_b=scale_b, bias=bias - ) - # A fix for discrepancy in scaled_mm which returns tuple - # for torch < 2.5 and a single value in torch >= 2.5 - if type(output) is tuple and len(output) == 2: - output = output[0] - - return torch.narrow(output, 0, 0, qinput.shape[0]).view(*output_shape) - - -def torch_per_token_w8a8_scaled_mm( - *, - qinput: torch.Tensor, - weight: torch.Tensor, - out_dtype: torch.dtype, - scale_a: torch.Tensor, - scale_b: torch.Tensor, - bias: torch.Tensor, - output_shape: list, - **kwargs, -) -> torch.Tensor: - # Note: Callers of this function should check USE_ROWWISE_TORCH_SCALED_MM - # when using it. - # For now it has only been validated on ROCm platform. 
- # fp8 rowwise scaling in torch._scaled_mm is introduced in - # https://github.com/pytorch/pytorch/pull/144432 using - # hipBLASLt and ROCm 6.3, which only exists in torch 2.7 and above. - # - # For CUDA platform please validate if the torch._scaled_mm supports - # rowwise scaled GEMM before using it - - # Fused GEMM_DQ Rowwise GEMM - output = torch._scaled_mm( - qinput, - weight, - out_dtype=out_dtype, - scale_a=scale_a, - scale_b=scale_b.t(), - bias=bias, - ) - - output = torch.narrow(output, 0, 0, qinput.shape[0]) - output = output.view(*output_shape) - return output - - -def torch_channelwise_w8a8_scaled_mm( - *, - qinput: torch.Tensor, - weight: torch.Tensor, - out_dtype: torch.dtype, - scale_a: torch.Tensor, - scale_b: torch.Tensor, - bias: torch.Tensor, - output_shape: list, - **kwargs, -) -> torch.Tensor: - # Use unfused DQ due to limitations with scaled_mm - - # Symmetric quantized GEMM by definition computes the following: - # C = (s_x * X) (s_w * W) + bias - # This is equivalent to dequantizing the weights and activations - # before applying a GEMM. - # - # In order to compute quantized operands, a quantized kernel - # will rewrite the above like so: - # C = s_w * s_x * (X * W) + bias - # - # For the scaled_mm fallback case, we break this down, since it - # does not support s_w being a vector. - - # GEMM - # This computes C = (X * W). 
- # Output in fp32 to allow subsequent ops to happen in-place - output = torch._scaled_mm( - qinput, - weight, - scale_a=TORCH_DEVICE_IDENTITY, - scale_b=TORCH_DEVICE_IDENTITY, - out_dtype=torch.float32, - ) - # A fix for discrepancy in scaled_mm which returns tuple - # for torch < 2.5 and a single value in torch >= 2.5 - if type(output) is tuple and len(output) == 2: - output = output[0] - # Unpad (undo num_token_padding) - output = torch.narrow(output, 0, 0, qinput.shape[0]) - x_scale = torch.narrow(scale_a, 0, 0, qinput.shape[0]) - - # DQ - # C = sw * sx * (X * W) + bias - output = output * x_scale * scale_b.t() - if bias is not None: - output = output + bias - return output.to(out_dtype).view(*output_shape) - - -def dispatch_w8a8_scaled_mm( - preferred_backend: str, per_tensor_weights: bool, per_tensor_activations: bool -) -> Callable[..., torch.Tensor]: - if per_tensor_weights and per_tensor_activations: - if preferred_backend == "rocm": - return rocm_per_tensor_w8a8_scaled_mm - if preferred_backend == "flashinfer": - return flashinfer_w8a8_scaled_mm - if preferred_backend == "cutlass": - return cutlass_w8a8_scaled_mm - return torch_per_tensor_w8a8_scaled_mm - - # cutlass_scaled_mm supports per tensor/channel W and per tensor/token A - if preferred_backend == "cutlass" or preferred_backend == "flashinfer": - return cutlass_w8a8_scaled_mm - - # If torch.scaled_mm supports per-channel (weights) per-token (inputs) - if ( - not per_tensor_weights - and not per_tensor_activations - and USE_ROWWISE_TORCH_SCALED_MM - ): - return torch_per_token_w8a8_scaled_mm - # Normally, torch.scaled_mm supports per tensor weights + activations only - # so fallback to naive if per channel or per token - return torch_channelwise_w8a8_scaled_mm - - -# TODO(luka): follow similar pattern for marlin and block-fp8-linear -# https://github.com/vllm-project/vllm/issues/14397 -class Fp8LinearOp: - """ - This class executes a FP8 linear layer using cutlass if supported and - torch.scaled_mm 
otherwise. - It needs to be a class instead of a method so that config can be read - in the __init__ method, as reading config is not allowed inside forward. - """ - - def __init__( - self, - act_quant_static: bool, - act_quant_group_shape: GroupShape = GroupShape.PER_TENSOR, - pad_output: bool | None = None, - ): - if current_platform.is_rocm(): - self.preferred_backend = "rocm" - elif current_platform.is_cuda() and cutlass_fp8_supported(): - if has_flashinfer() and current_platform.has_device_capability(100): - self.preferred_backend = "flashinfer" - else: - self.preferred_backend = "cutlass" - else: - self.preferred_backend = "torch" - - # Note: we pad the input because torch._scaled_mm is more performant - # for matrices with batch dimension > 16. - # This could change in the future. - # We also don't pad when using torch.compile, - # as it breaks with dynamic shapes. - if pad_output is None: - config = get_current_vllm_config().compilation_config - pad_output = ( - config.mode < CompilationMode.VLLM_COMPILE - and self.preferred_backend == "torch" - ) - - self.output_padding = 17 if pad_output else None - self.act_quant_static = act_quant_static - self.act_quant_group_shape = act_quant_group_shape - self.quant_fp8 = QuantFP8( - static=act_quant_static, - group_shape=act_quant_group_shape, - num_token_padding=self.output_padding, - ) - - def apply( - self, - input: torch.Tensor, - weight: torch.Tensor, - weight_scale: torch.Tensor, - out_dtype: torch.dtype | None = None, - input_scale: torch.Tensor | None = None, - input_scale_ub: torch.Tensor | None = None, - bias: torch.Tensor | None = None, - ) -> torch.Tensor: - # ops.scaled_fp8_quant supports both dynamic and static quant. - # If dynamic, layer.input_scale is None and x_scale computed from x. - # If static, layer.input_scale is scalar and x_scale is input_scale. 
- - # View input as 2D matrix for fp8 methods - input_2d = input.view(-1, input.shape[-1]) - output_shape = [*input.shape[:-1], weight.shape[1]] - - if out_dtype is None: - out_dtype = input.dtype - - # If input not quantized - # TODO(luka) remove this path if not used anymore - if input.dtype != current_platform.fp8_dtype(): - qinput, x_scale = self.quant_fp8( - input_2d, - input_scale, - input_scale_ub, - ) - else: - qinput, x_scale = input_2d, input_scale - - # Must have dim() conditions - # In per-token quant scenario, when the number of token is 1, - # the scale will only have 1 elements. - # Without checking the dim(), - # we cannot distingushes between per-tensor and per-token quant. - # Example: - # When the number of token is 1, per-token scale is [[1]] - # When per-tensor scale is [1] or (). - per_tensor_weights = (weight_scale.numel() == 1) and weight_scale.dim() < 2 - per_tensor_activations = (x_scale.numel() == 1) and x_scale.dim() < 2 - - # TODO(luka) do this dispatch during init (after ScaledMM refactor) - w8a8_scaled_mm_func = dispatch_w8a8_scaled_mm( - self.preferred_backend, per_tensor_weights, per_tensor_activations - ) - - return w8a8_scaled_mm_func( - qinput=qinput, - weight=weight, - out_dtype=out_dtype, - scale_a=x_scale, - scale_b=weight_scale, - bias=bias, - output_shape=output_shape, - ) - - def normalize_e4m3fn_to_e4m3fnuz( weight: torch.Tensor, weight_scale: torch.Tensor, From 77940096618976a32fdcedb4616220e3da18fceb Mon Sep 17 00:00:00 2001 From: vllmellm Date: Mon, 3 Nov 2025 07:09:52 +0000 Subject: [PATCH 16/36] add missing arg Signed-off-by: vllmellm --- .../layers/quantization/kernels/scaled_mm/flash_infer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/flash_infer.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/flash_infer.py index e33b30532204..9b0ac38db5e3 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/flash_infer.py +++ 
b/vllm/model_executor/layers/quantization/kernels/scaled_mm/flash_infer.py @@ -22,6 +22,7 @@ def flashinfer_w8a8_scaled_mm( As: torch.Tensor, Bs: torch.Tensor, bias: torch.Tensor, + output_shape: list, ) -> torch.Tensor: return flashinfer_scaled_fp8_mm( A, B, out_dtype=out_dtype, scale_a=As, scale_b=Bs, bias=bias From a8010c7b1c83aa884a3212925c442d37204fb14e Mon Sep 17 00:00:00 2001 From: vllmellm Date: Mon, 3 Nov 2025 08:02:45 +0000 Subject: [PATCH 17/36] flash_infer missing out dtype bug fix Signed-off-by: vllmellm --- .../layers/quantization/kernels/scaled_mm/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/utils.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/utils.py index 62bbacbc782c..e5ab5ad4d47c 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/utils.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/utils.py @@ -17,7 +17,7 @@ def apply_weights_fp8( x_s: torch.Tensor, bias: torch.Tensor, x_s_ub: torch.Tensor | None, - maybe_out_dtype: torch.dtype | None, + maybe_out_dtype: torch.dtype | None = None, ) -> torch.Tensor: # ops.scaled_fp8_quant supports both dynamic and static quant. # If dynamic, layer.input_scale is None and x_s computed from x. 
From f5e6cd9695848739d56acc46f89b29db8e0769bf Mon Sep 17 00:00:00 2001 From: vllmellm Date: Tue, 4 Nov 2025 12:11:13 +0000 Subject: [PATCH 18/36] prefer QuantKey over ScaledMMLinearQuantStrategy Signed-off-by: vllmellm --- tests/compile/test_functionalization.py | 19 ++++----- tests/compile/test_fusion.py | 26 ++++++------ tests/compile/test_fusion_all_reduce.py | 14 +++---- tests/compile/test_fusion_attn.py | 14 ++----- tests/compile/test_sequence_parallelism.py | 12 +++--- tests/compile/test_silu_mul_quant_fusion.py | 12 ++---- .../schemes/compressed_tensors_w8a8_fp8.py | 39 +++++++++--------- .../layers/quantization/fbgemm_fp8.py | 17 +++----- .../model_executor/layers/quantization/fp8.py | 20 +++++----- .../kernels/scaled_mm/ScaledMMLinearKernel.py | 30 +++++--------- .../kernels/scaled_mm/__init__.py | 13 +++--- .../kernels/scaled_mm/flash_infer.py | 7 ++-- .../quantization/kernels/scaled_mm/pytorch.py | 19 +++++---- .../quantization/kernels/scaled_mm/rocm.py | 21 +++++----- .../layers/quantization/modelopt.py | 12 ++---- .../layers/quantization/ptpc_fp8.py | 16 +++----- .../quark/schemes/quark_w8a8_fp8.py | 40 +++++++++---------- .../layers/quantization/utils/quant_utils.py | 3 ++ 18 files changed, 147 insertions(+), 187 deletions(-) diff --git a/tests/compile/test_functionalization.py b/tests/compile/test_functionalization.py index 4d979f075d78..a40f8beccdc2 100644 --- a/tests/compile/test_functionalization.py +++ b/tests/compile/test_functionalization.py @@ -23,10 +23,9 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( init_fp8_linear_kernel, ) -from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501 - ScaledMMLinearQuantStrategy, +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + kFp8StaticTensorSym, ) -from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape from vllm.model_executor.layers.rotary_embedding import get_rope from 
vllm.platforms import current_platform @@ -37,6 +36,8 @@ class TestSiluMul(torch.nn.Module): + quant_key = kFp8StaticTensorSym + def __init__(self, hidden_size: int = 128): super().__init__() self.silu_and_mul = SiluAndMul() @@ -46,9 +47,8 @@ def __init__(self, hidden_size: int = 128): if TEST_FP8: self.weight = torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t() self.fp8_linear = init_fp8_linear_kernel( - act_q_static=True, - act_q_group_shape=GroupShape.PER_TENSOR, - weight_quant_strategy=ScaledMMLinearQuantStrategy.TENSOR, + activation_quant_key=self.quant_key, + weight_quant_key=self.quant_key, out_dtype=torch.get_default_dtype(), module_name=self.__class__.__name__, ) @@ -74,6 +74,8 @@ def ops_not_in_model(self): class TestFusedAddRMSNorm(torch.nn.Module): + quant_key = kFp8StaticTensorSym + def __init__(self, hidden_size=16, intermediate_size=32): super().__init__() self.hidden_size = hidden_size @@ -89,9 +91,8 @@ def __init__(self, hidden_size=16, intermediate_size=32): if TEST_FP8: self.fp8_linear = init_fp8_linear_kernel( - act_q_static=True, - act_q_group_shape=GroupShape.PER_TENSOR, - weight_quant_strategy=ScaledMMLinearQuantStrategy.TENSOR, + activation_quant_key=self.quant_key, + weight_quant_key=self.quant_key, out_dtype=torch.get_default_dtype(), module_name=self.__class__.__name__, ) diff --git a/tests/compile/test_fusion.py b/tests/compile/test_fusion.py index ed925a4d55cc..6270344c2eb3 100644 --- a/tests/compile/test_fusion.py +++ b/tests/compile/test_fusion.py @@ -21,9 +21,6 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( init_fp8_linear_kernel, ) -from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501 - ScaledMMLinearQuantStrategy, -) from vllm.model_executor.layers.quantization.utils.quant_utils import ( GroupShape, QuantKey, @@ -59,10 +56,16 @@ def __init__( self.norm = [RMSNorm(hidden_size, eps) for _ in range(4)] self.wscale = [torch.rand(1, 
dtype=torch.float32) for _ in range(3)] group_shape = GroupShape.PER_TENSOR if static else GroupShape.PER_TOKEN - weight_quant_strategy = ScaledMMLinearQuantStrategy.TENSOR - quant_scale = ScaleDesc(torch.float32, static, group_shape) - self.quant_key = QuantKey(dtype=FP8_DTYPE, scale=quant_scale, symmetric=True) + act_quant_scale = ScaleDesc(torch.float32, static, group_shape) + w_quant_scale = ScaleDesc(torch.float32, True, group_shape) + self.activation_quant_key = QuantKey( + dtype=FP8_DTYPE, scale=act_quant_scale, symmetric=True + ) + self.weight_quant_key = QuantKey( + dtype=FP8_DTYPE, scale=w_quant_scale, symmetric=True + ) + if static: self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(3)] else: @@ -74,9 +77,8 @@ def __init__( with override_cutlass_fp8_supported(not cuda_force_torch): self.fp8_linear = init_fp8_linear_kernel( - act_q_static=static, - act_q_group_shape=group_shape, - weight_quant_strategy=weight_quant_strategy, + activation_quant_key=self.activation_quant_key, + weight_quant_key=self.weight_quant_key, out_dtype=torch.get_default_dtype(), module_name=self.__class__.__name__, ) @@ -110,13 +112,13 @@ def forward(self, x): def ops_in_model_after(self): return [ - FUSED_OPS[FusedRMSQuantKey(self.quant_key, True)], - FUSED_OPS[FusedRMSQuantKey(self.quant_key, False)], + FUSED_OPS[FusedRMSQuantKey(self.activation_quant_key, True)], + FUSED_OPS[FusedRMSQuantKey(self.activation_quant_key, False)], ] def ops_in_model_before(self): return ( - [QUANT_OPS[self.quant_key]] + [QUANT_OPS[self.activation_quant_key]] if self.enable_quant_fp8_custom_op else [torch.ops.aten.reciprocal] ) diff --git a/tests/compile/test_fusion_all_reduce.py b/tests/compile/test_fusion_all_reduce.py index 2dc6f8d2f925..5e2c46f8ea91 100644 --- a/tests/compile/test_fusion_all_reduce.py +++ b/tests/compile/test_fusion_all_reduce.py @@ -29,11 +29,8 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( init_fp8_linear_kernel, ) -from 
vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501 - ScaledMMLinearQuantStrategy, -) -from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - GroupShape, +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + kFp8StaticTensorSym, ) from vllm.platforms import current_platform from vllm.utils.system_utils import update_environment_variables @@ -80,6 +77,8 @@ def ops_in_model_after(self): class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module): + quant_key = kFp8StaticTensorSym + def __init__(self, hidden_size=16, token_num=16, eps=1e-6): super().__init__() self.hidden_size = hidden_size @@ -95,9 +94,8 @@ def __init__(self, hidden_size=16, token_num=16, eps=1e-6): ] self.fp8_linear = init_fp8_linear_kernel( - act_q_static=True, - act_q_group_shape=GroupShape.PER_TENSOR, - weight_quant_strategy=ScaledMMLinearQuantStrategy.TENSOR, + activation_quant_key=self.quant_key, + weight_quant_key=self.quant_key, out_dtype=torch.get_default_dtype(), module_name=self.__class__.__name__, ) diff --git a/tests/compile/test_fusion_attn.py b/tests/compile/test_fusion_attn.py index a6ebf46d98dd..9068e304f551 100644 --- a/tests/compile/test_fusion_attn.py +++ b/tests/compile/test_fusion_attn.py @@ -31,9 +31,6 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( init_fp8_linear_kernel, ) -from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501 - ScaledMMLinearQuantStrategy, -) from vllm.model_executor.layers.quantization.utils.quant_utils import ( QuantKey, kFp8StaticTensorSym, @@ -177,18 +174,13 @@ class TestAttentionFp8StaticQuantPatternModel(AttentionQuantPatternModel): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - if self.quant_key.scale.group_shape.is_per_tensor(): - weight_quant_strategy = ScaledMMLinearQuantStrategy.TENSOR - else: - weight_quant_strategy = ScaledMMLinearQuantStrategy.CHANNEL 
- self.fp8_linear = init_fp8_linear_kernel( - act_q_static=self.quant_key.scale.static, - act_q_group_shape=self.quant_key.scale.group_shape, - weight_quant_strategy=weight_quant_strategy, + activation_quant_key=self.quant_key, + weight_quant_key=self.quant_key, out_dtype=torch.get_default_dtype(), module_name=self.__class__.__name__, ) + hidden_size = self.num_qo_heads * self.head_size self.w = kwargs.get( "w", diff --git a/tests/compile/test_sequence_parallelism.py b/tests/compile/test_sequence_parallelism.py index 007339cd86f7..f579815338a9 100644 --- a/tests/compile/test_sequence_parallelism.py +++ b/tests/compile/test_sequence_parallelism.py @@ -30,10 +30,9 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( init_fp8_linear_kernel, ) -from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501 - ScaledMMLinearQuantStrategy, +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + kFp8StaticTensorSym, ) -from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape from vllm.platforms import current_platform from vllm.utils.system_utils import update_environment_variables @@ -101,6 +100,8 @@ def ops_in_model(self): class TestQuantModel(torch.nn.Module): + quant_key = kFp8StaticTensorSym + def __init__(self, hidden_size=16, intermediate_size=32): super().__init__() self.hidden_size = hidden_size @@ -114,9 +115,8 @@ def __init__(self, hidden_size=16, intermediate_size=32): torch.nn.init.normal_(self.gate_proj, std=0.02) self.fp8_linear = init_fp8_linear_kernel( - act_q_static=True, - act_q_group_shape=GroupShape.PER_TENSOR, - weight_quant_strategy=ScaledMMLinearQuantStrategy.TENSOR, + activation_quant_key=self.quant_key, + weight_quant_key=self.quant_key, out_dtype=torch.get_default_dtype(), module_name=self.__class__.__name__, ) diff --git a/tests/compile/test_silu_mul_quant_fusion.py b/tests/compile/test_silu_mul_quant_fusion.py index 
2ce52b97f13e..20e7c2955d01 100644 --- a/tests/compile/test_silu_mul_quant_fusion.py +++ b/tests/compile/test_silu_mul_quant_fusion.py @@ -27,11 +27,7 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( init_fp8_linear_kernel, ) -from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501 - ScaledMMLinearQuantStrategy, -) from vllm.model_executor.layers.quantization.utils.quant_utils import ( - GroupShape, kFp8StaticTensorSym, kNvfp4Quant, ) @@ -52,6 +48,8 @@ def is_nvfp4_supported(): class TestSiluMulFp8QuantModel(torch.nn.Module): + quant_key = kFp8StaticTensorSym + def __init__(self, hidden_size: int, cuda_force_torch: bool, **kwargs): super().__init__() self.silu_and_mul = SiluAndMul() @@ -62,13 +60,11 @@ def __init__(self, hidden_size: int, cuda_force_torch: bool, **kwargs): with override_cutlass_fp8_supported(not cuda_force_torch): self.fp8_linear = init_fp8_linear_kernel( - act_q_static=True, - act_q_group_shape=GroupShape.PER_TENSOR, - weight_quant_strategy=ScaledMMLinearQuantStrategy.TENSOR, + activation_quant_key=self.quant_key, + weight_quant_key=self.quant_key, out_dtype=torch.get_default_dtype(), module_name=self.__class__.__name__, ) - self.enable_silu_mul_custom_op = self.silu_and_mul.enabled() self.enable_quant_fp8_custom_op = self.fp8_linear.quant_fp8.enabled() diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py index 58ea30edcd63..0d14c13180ab 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py @@ -14,9 +14,6 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( init_fp8_linear_kernel, ) -from 
vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501 - QUANT_STRATEGY_MAP, -) from vllm.model_executor.layers.quantization.utils.fp8_utils import ( W8A8BlockFp8LinearOp, check_aiter_fp8_linear_support, @@ -29,7 +26,11 @@ process_fp8_weight_tensor_strategy, validate_fp8_block_shape, ) -from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + GroupShape, + kFp8DynamicTokenSym, + kFp8StaticTensorSym, +) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( cutlass_block_fp8_supported, maybe_create_device_identity, @@ -48,6 +49,12 @@ QuantizationStrategy.TENSOR: PerTensorScaleParameter, } +STATIC_QUANT = True +DYNAMIC_QUANT = False +quant_keys = { + STATIC_QUANT: (kFp8StaticTensorSym, kFp8StaticTensorSym), + DYNAMIC_QUANT: (kFp8DynamicTokenSym, kFp8StaticTensorSym), +} logger = init_logger(__name__) @@ -57,22 +64,13 @@ def __init__(self, weight_quant: QuantizationArgs, is_static_input_scheme: bool) self.strategy = weight_quant.strategy self.out_dtype = torch.get_default_dtype() self.is_static_input_scheme = is_static_input_scheme - self.weight_block_size = self.weight_quant.block_structure - if self.weight_block_size is not None: - self.act_q_group_shape = GroupShape(1, self.weight_block_size[0]) - else: - self.act_q_group_shape = ( - GroupShape.PER_TENSOR - if is_static_input_scheme - else GroupShape.PER_TOKEN - ) - self.cutlass_block_fp8_supported = cutlass_block_fp8_supported() self.use_aiter_and_is_supported = check_aiter_fp8_linear_support() if self.weight_block_size is not None: assert not self.is_static_input_scheme + self.act_q_group_shape = GroupShape(1, self.weight_block_size[0]) self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp( weight_group_shape=GroupShape(*self.weight_block_size), act_quant_group_shape=self.act_q_group_shape, @@ -80,12 +78,11 @@ def __init__(self, weight_quant: 
QuantizationArgs, is_static_input_scheme: bool) use_aiter_and_is_supported=self.use_aiter_and_is_supported, ) else: - weight_quant_strategy = QUANT_STRATEGY_MAP[self.strategy] - self.fp8_linear_kernel = init_fp8_linear_kernel( - act_q_static=self.is_static_input_scheme, - act_q_group_shape=self.act_q_group_shape, - weight_quant_strategy=weight_quant_strategy, - out_dtype=self.out_dtype, + activation_quant_key, weight_quant_key = quant_keys[is_static_input_scheme] + self.fp8_linear = init_fp8_linear_kernel( + activation_quant_key=activation_quant_key, + weight_quant_key=weight_quant_key, + out_dtype=torch.get_default_dtype(), module_name=self.__class__.__name__, ) @@ -204,4 +201,4 @@ def apply_weights( bias=bias, ) - return self.fp8_linear_kernel.apply_weights(layer, x, bias) + return self.fp8_linear.apply_weights(layer, x, bias) diff --git a/vllm/model_executor/layers/quantization/fbgemm_fp8.py b/vllm/model_executor/layers/quantization/fbgemm_fp8.py index 5fa419ebaa91..c19dd708b233 100644 --- a/vllm/model_executor/layers/quantization/fbgemm_fp8.py +++ b/vllm/model_executor/layers/quantization/fbgemm_fp8.py @@ -21,16 +21,13 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( init_fp8_linear_kernel, ) -from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501 - ScaledMMLinearQuantStrategy, -) from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin, ) from vllm.model_executor.layers.quantization.utils.quant_utils import ( - GroupShape, is_layer_skipped, + kFp8DynamicTokenSym, ) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( maybe_create_device_identity, @@ -97,12 +94,10 @@ class FBGEMMFp8LinearMethod(LinearMethodBase): def __init__(self, quant_config: FBGEMMFp8Config): self.quant_config = quant_config self.out_dtype = torch.get_default_dtype() - - self.fp8_linear_kernel = init_fp8_linear_kernel( - 
act_q_static=False, - act_q_group_shape=GroupShape.PER_TOKEN, - weight_quant_strategy=ScaledMMLinearQuantStrategy.CHANNEL, - out_dtype=self.out_dtype, + self.fp8_linear = init_fp8_linear_kernel( + activation_quant_key=kFp8DynamicTokenSym, + weight_quant_key=kFp8DynamicTokenSym, + out_dtype=torch.get_default_dtype(), module_name=self.__class__.__name__, ) @@ -194,4 +189,4 @@ def apply( bias=bias, ) - return self.fp8_linear_kernel.apply_weights(layer, x, bias) + return self.fp8_linear.apply_weights(layer, x, bias) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 48697e3849e0..c04bcef7bb0b 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -45,10 +45,6 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( init_fp8_linear_kernel, ) -from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa E501 - FP8ScaledMMLinearLayerConfig, - ScaledMMLinearQuantStrategy, -) from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( FlashinferMoeBackend, @@ -82,6 +78,9 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( GroupShape, is_layer_skipped, + kFp8DynamicTensorSym, + kFp8StaticTensorSym, + kFp8StaticTokenSym, ) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( all_close_1d, @@ -380,8 +379,10 @@ def __init__(self, quant_config: Fp8Config): # Use per-token quantization for better perf if dynamic and cutlass if not self.act_q_static and cutlass_fp8_supported(): self.act_q_group_shape = GroupShape.PER_TOKEN + self.activation_quant_key = kFp8StaticTokenSym else: self.act_q_group_shape = GroupShape.PER_TENSOR + self.activation_quant_key = kFp8DynamicTensorSym if self.block_quant: assert not self.act_q_static @@ -393,11 +394,10 @@ def __init__(self, 
quant_config: Fp8Config): use_aiter_and_is_supported=self.use_aiter_and_is_supported, ) else: - self.fp8_linear_kernel = init_fp8_linear_kernel( - act_q_static=self.act_q_static, - act_q_group_shape=self.act_q_group_shape, - weight_quant_strategy=ScaledMMLinearQuantStrategy.TENSOR, - out_dtype=self.out_dtype, + self.fp8_linear = init_fp8_linear_kernel( + activation_quant_key=self.activation_quant_key, + weight_quant_key=kFp8StaticTensorSym, + out_dtype=torch.get_default_dtype(), module_name=self.__class__.__name__, ) @@ -684,7 +684,7 @@ def apply( bias=bias, ) - return self.fp8_linear_kernel.apply_weights(layer, x, bias) + return self.fp8_linear.apply_weights(layer, x, bias) class Fp8MoEMethod(FusedMoEMethodBase): diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py index f3bff8cae0ef..a8a2fc245f62 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py @@ -4,43 +4,32 @@ from abc import ABC, abstractmethod from collections.abc import Sequence from dataclasses import dataclass -from enum import Enum from typing import Generic, TypeVar import torch -from compressed_tensors.quantization import QuantizationStrategy from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 -from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape - - -class ScaledMMLinearQuantStrategy(Enum): - TENSOR = "tensor" - CHANNEL = "channel" - BLOCK = "block" - - -QUANT_STRATEGY_MAP = { - QuantizationStrategy.TENSOR: ScaledMMLinearQuantStrategy.TENSOR, - QuantizationStrategy.CHANNEL: ScaledMMLinearQuantStrategy.CHANNEL, -} +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + QuantKey, +) @dataclass class ScaledMMLinearLayerConfig: - is_static_input_scheme: bool + pass @dataclass 
class Int8ScaledMMLinearLayerConfig(ScaledMMLinearLayerConfig): + is_static_input_scheme: bool is_channelwise: bool input_symmetric: bool @dataclass class FP8ScaledMMLinearLayerConfig(ScaledMMLinearLayerConfig): - weight_quant_strategy: ScaledMMLinearQuantStrategy - activation_group_shape: GroupShape + weight_quant_key: QuantKey + activation_quant_key: QuantKey out_dtype: torch.dtype | None @@ -103,9 +92,10 @@ class FP8ScaledMMLinearKernel( def __init__( self, c: FP8ScaledMMLinearLayerConfig, layer_param_names: Sequence[str] ) -> None: + act_scale_descriptor = c.activation_quant_key.scale self.quant_fp8 = QuantFP8( - static=c.is_static_input_scheme, - group_shape=c.activation_group_shape, + static=act_scale_descriptor.static, + group_shape=act_scale_descriptor.group_shape, num_token_padding=self.get_ouput_padding(), ) super().__init__(c, layer_param_names) diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py index 901f0649a6d4..b36b77109e92 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py @@ -32,7 +32,6 @@ Int8ScaledMMLinearLayerConfig, ScaledMMLinearKernel, ScaledMMLinearLayerConfig, - ScaledMMLinearQuantStrategy, ) from vllm.model_executor.layers.quantization.kernels.scaled_mm.triton import ( TritonScaledMMLinearKernel, @@ -40,7 +39,7 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm.xla import ( XLAScaledMMLinearKernel, ) -from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape +from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey from vllm.platforms import PlatformEnum, current_platform logger = init_logger(__name__) @@ -137,16 +136,14 @@ def choose_scaled_mm_linear_kernel( def init_fp8_linear_kernel( - act_q_static: bool, - act_q_group_shape: GroupShape, - weight_quant_strategy: 
ScaledMMLinearQuantStrategy, + activation_quant_key: QuantKey, + weight_quant_key: QuantKey, out_dtype: torch.dtype, module_name: str, ) -> FP8ScaledMMLinearKernel: scaled_mm_linear_kernel_config = FP8ScaledMMLinearLayerConfig( - is_static_input_scheme=act_q_static, - weight_quant_strategy=weight_quant_strategy, - activation_group_shape=act_q_group_shape, + weight_quant_key=weight_quant_key, + activation_quant_key=activation_quant_key, out_dtype=out_dtype, ) diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/flash_infer.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/flash_infer.py index 9b0ac38db5e3..3bac71950dda 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/flash_infer.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/flash_infer.py @@ -9,7 +9,6 @@ from .ScaledMMLinearKernel import ( FP8ScaledMMLinearKernel, FP8ScaledMMLinearLayerConfig, - ScaledMMLinearQuantStrategy, ) from .utils import apply_weights_fp8 @@ -39,10 +38,10 @@ def get_min_capability(cls) -> int: @classmethod def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: - per_tensor_activation_scales = c.activation_group_shape.is_per_tensor() - per_tensor_weight_scales = ( - c.weight_quant_strategy == ScaledMMLinearQuantStrategy.TENSOR + per_tensor_activation_scales = ( + c.activation_quant_key.scale.group_shape.is_per_tensor() ) + per_tensor_weight_scales = c.weight_quant_key.scale.group_shape.is_per_tensor() if not current_platform.is_cuda(): return ( diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/pytorch.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/pytorch.py index c0466e840fc0..7c4c64215a8e 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/pytorch.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/pytorch.py @@ -10,7 +10,6 @@ from .ScaledMMLinearKernel import ( FP8ScaledMMLinearKernel, FP8ScaledMMLinearLayerConfig, - 
ScaledMMLinearQuantStrategy, ) from .utils import apply_weights_fp8 @@ -143,10 +142,10 @@ def get_ouput_padding(self) -> int | None: class PerTensorTorchScaledMMLinearKernel(TorchScaledMMLinearKernel): @classmethod def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: - per_tensor_activation_scales = c.activation_group_shape.is_per_tensor() - per_tensor_weight_scales = ( - c.weight_quant_strategy == ScaledMMLinearQuantStrategy.TENSOR + per_tensor_activation_scales = ( + c.activation_quant_key.scale.group_shape.is_per_tensor() ) + per_tensor_weight_scales = c.weight_quant_key.scale.group_shape.is_per_tensor() if not (per_tensor_activation_scales and per_tensor_weight_scales): return ( @@ -183,10 +182,10 @@ def get_min_capability(cls) -> int: @classmethod def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: - per_tensor_activation_scales = c.activation_group_shape.is_per_tensor() - per_tensor_weight_scales = ( - c.weight_quant_strategy == ScaledMMLinearQuantStrategy.TENSOR + per_tensor_activation_scales = ( + c.activation_quant_key.scale.group_shape.is_per_tensor() ) + per_tensor_weight_scales = c.weight_quant_key.scale.group_shape.is_per_tensor() if per_tensor_activation_scales or per_tensor_weight_scales: return ( @@ -237,10 +236,10 @@ def get_min_capability(cls) -> int: @classmethod def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: - per_tensor_activation_scales = c.activation_group_shape.is_per_tensor() - per_tensor_weight_scales = ( - c.weight_quant_strategy == ScaledMMLinearQuantStrategy.TENSOR + per_tensor_activation_scales = ( + c.activation_quant_key.scale.group_shape.is_per_tensor() ) + per_tensor_weight_scales = c.weight_quant_key.scale.group_shape.is_per_tensor() if per_tensor_activation_scales and per_tensor_weight_scales: return ( diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/rocm.py 
b/vllm/model_executor/layers/quantization/kernels/scaled_mm/rocm.py index 63744337a7e5..26463a19c6f4 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/rocm.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/rocm.py @@ -11,7 +11,6 @@ from .ScaledMMLinearKernel import ( FP8ScaledMMLinearKernel, FP8ScaledMMLinearLayerConfig, - ScaledMMLinearQuantStrategy, ) from .utils import apply_weights_fp8 @@ -72,18 +71,18 @@ def rocm_per_tensor_float_w8a8_scaled_mm( bias: torch.Tensor, output_shape: list[int], ) -> torch.Tensor: - output = torch.ops.vllm.rocm_per_tensor_w8a8_scaled_mm_impl( + output = torch.ops.vllm.rocm_per_tensor_float_w8a8_scaled_mm_impl( A, B, out_dtype, As, Bs, bias ) return torch.narrow(output, 0, 0, A.shape[0]).view(*output_shape) -if current_platform.is_rocm(): - direct_register_custom_op( - op_name="rocm_per_tensor_float_w8a8_scaled_mm_impl", - op_func=rocm_per_tensor_float_w8a8_scaled_mm_impl, - fake_impl=rocm_per_tensor_float_w8a8_scaled_mm_fake, - ) +# if current_platform.is_rocm(): +direct_register_custom_op( + op_name="rocm_per_tensor_float_w8a8_scaled_mm_impl", + op_func=rocm_per_tensor_float_w8a8_scaled_mm_impl, + fake_impl=rocm_per_tensor_float_w8a8_scaled_mm_fake, +) class ROCmScaledMMLinearKernel(FP8ScaledMMLinearKernel): @@ -95,10 +94,10 @@ def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | Non # TODO: check if this causes an issue on non-ROCM platforms from vllm.platforms.rocm import on_mi3xx - per_tensor_activation_scales = c.activation_group_shape.is_per_tensor() - per_tensor_weight_scales = ( - c.weight_quant_strategy == ScaledMMLinearQuantStrategy.TENSOR + per_tensor_activation_scales = ( + c.activation_quant_key.scale.group_shape.is_per_tensor() ) + per_tensor_weight_scales = c.weight_quant_key.scale.group_shape.is_per_tensor() if not current_platform.is_rocm(): return ( diff --git a/vllm/model_executor/layers/quantization/modelopt.py 
b/vllm/model_executor/layers/quantization/modelopt.py index f478cd319e66..53b25af44035 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -40,9 +40,6 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( init_fp8_linear_kernel, ) -from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501 - ScaledMMLinearQuantStrategy, -) from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import ( build_flashinfer_fp4_cutlass_moe_prepare_finalize, @@ -68,9 +65,9 @@ prepare_moe_fp4_layer_for_marlin, ) from vllm.model_executor.layers.quantization.utils.quant_utils import ( - GroupShape, cutlass_fp4_supported, is_layer_skipped, + kFp8StaticTensorSym, swizzle_blockscale, ) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( @@ -260,10 +257,9 @@ class ModelOptFp8LinearMethod(LinearMethodBase): def __init__(self, quant_config: ModelOptFp8Config) -> None: self.quant_config = quant_config self.fp8_linear = init_fp8_linear_kernel( - act_q_static=True, - act_q_group_shape=GroupShape.PER_TENSOR, - weight_quant_strategy=ScaledMMLinearQuantStrategy.TENSOR, - out_dtype=None, + activation_quant_key=kFp8StaticTensorSym, + weight_quant_key=kFp8StaticTensorSym, + out_dtype=torch.get_default_dtype(), module_name=self.__class__.__name__, ) diff --git a/vllm/model_executor/layers/quantization/ptpc_fp8.py b/vllm/model_executor/layers/quantization/ptpc_fp8.py index 2634bbd4bd87..c102c52bbe3f 100644 --- a/vllm/model_executor/layers/quantization/ptpc_fp8.py +++ b/vllm/model_executor/layers/quantization/ptpc_fp8.py @@ -19,12 +19,9 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( init_fp8_linear_kernel, ) -from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa E501 - 
ScaledMMLinearQuantStrategy, -) from vllm.model_executor.layers.quantization.utils.quant_utils import ( - GroupShape, is_layer_skipped, + kFp8DynamicTokenSym, ) from vllm.platforms import current_platform @@ -103,11 +100,10 @@ def __init__(self, quant_config: PTPCFp8Config): ) super().__init__(quant_config=quant_config) # Force weight quantization - self.fp8_linear_kernel = init_fp8_linear_kernel( - act_q_static=False, - act_q_group_shape=GroupShape.PER_TOKEN, - weight_quant_strategy=ScaledMMLinearQuantStrategy.CHANNEL, - out_dtype=self.out_dtype, + self.fp8_linear = init_fp8_linear_kernel( + activation_quant_key=kFp8DynamicTokenSym, + weight_quant_key=kFp8DynamicTokenSym, + out_dtype=torch.get_default_dtype(), module_name=self.__class__.__name__, ) @@ -135,4 +131,4 @@ def apply( x: torch.Tensor, bias: torch.Tensor | None = None, ) -> torch.Tensor: - return self.fp8_linear_kernel.apply_weights(layer, x, bias) + return self.fp8_linear.apply_weights(layer, x, bias) diff --git a/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py b/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py index 6fff44900007..343539c10fa8 100644 --- a/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py +++ b/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py @@ -11,11 +11,13 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( init_fp8_linear_kernel, ) -from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501 - ScaledMMLinearQuantStrategy, -) from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme -from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + GroupShape, + kFp8DynamicTokenSym, + kFp8StaticTensorSym, + kFp8StaticTokenSym, +) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( 
normalize_e4m3fn_to_e4m3fnuz, requantize_with_max_scale, @@ -31,11 +33,6 @@ logger = init_logger(__name__) -QUANT_STRATEGY_MAP = { - "per_tensor": ScaledMMLinearQuantStrategy.TENSOR, - "per_channel": ScaledMMLinearQuantStrategy.CHANNEL, -} - class QuarkW8A8Fp8(QuarkScheme): def __init__( @@ -48,11 +45,16 @@ def __init__( self.is_static_input_scheme = not cast(bool, input_config.get("is_dynamic")) self.input_qscheme = cast(str, input_config.get("qscheme")) - per_token = ( + per_token_activation = ( not self.is_static_input_scheme and self.input_qscheme == "per_channel" ) - self.act_quant_group_shape = ( - GroupShape.PER_TOKEN if per_token else GroupShape.PER_TENSOR + per_token_weight = self.weight_qscheme == "per_channel" + + self.activation_quant_key = ( + kFp8DynamicTokenSym if per_token_activation else kFp8StaticTensorSym + ) + self.weight_quant_key = ( + kFp8StaticTokenSym if per_token_weight else kFp8StaticTensorSym ) self.out_dtype = torch.get_default_dtype() @@ -103,7 +105,7 @@ def process_weights_after_loading(self, layer) -> None: layer.input_scale = Parameter(input_scale, requires_grad=False) else: weight_scale = layer.weight_scale.data - if self.act_quant_group_shape == GroupShape.PER_TOKEN: + if self.activation_quant_key.scale.group_shape == GroupShape.PER_TOKEN: weight_scale = weight_scale.view(-1, 1) layer.weight = Parameter(weight.t(), requires_grad=False) # required by torch.compile to be torch.nn.Parameter @@ -174,12 +176,10 @@ def create_weights( layer.register_parameter("input_scale_ub", None) - weight_quant_strategy = QUANT_STRATEGY_MAP[self.weight_qscheme] - self.fp8_linear_kernel = init_fp8_linear_kernel( - act_q_static=self.is_static_input_scheme, - act_q_group_shape=self.act_quant_group_shape, - weight_quant_strategy=weight_quant_strategy, - out_dtype=self.out_dtype, + self.fp8_linear = init_fp8_linear_kernel( + activation_quant_key=self.activation_quant_key, + weight_quant_key=self.weight_quant_key, + out_dtype=torch.get_default_dtype(), 
module_name=self.__class__.__name__, ) @@ -189,4 +189,4 @@ def apply_weights( x: torch.Tensor, bias: torch.Tensor | None = None, ) -> torch.Tensor: - return self.fp8_linear_kernel.apply_weights(layer, x, bias) + return self.fp8_linear.apply_weights(layer, x, bias) diff --git a/vllm/model_executor/layers/quantization/utils/quant_utils.py b/vllm/model_executor/layers/quantization/utils/quant_utils.py index d056d3404385..2c8a614c9e71 100644 --- a/vllm/model_executor/layers/quantization/utils/quant_utils.py +++ b/vllm/model_executor/layers/quantization/utils/quant_utils.py @@ -109,6 +109,9 @@ def __str__(self): kDynamicTensorScale = ScaleDesc(torch.float32, False, GroupShape.PER_TENSOR) kFp8DynamicTensorSym = QuantKey(FP8_DTYPE, kDynamicTensorScale, symmetric=True) +kStaticTokenScale = ScaleDesc(torch.float32, True, GroupShape.PER_TOKEN) +kFp8StaticTokenSym = QuantKey(FP8_DTYPE, kStaticTokenScale, symmetric=True) + kDynamicTokenScale = ScaleDesc(torch.float32, False, GroupShape.PER_TOKEN) kFp8DynamicTokenSym = QuantKey(FP8_DTYPE, kDynamicTokenScale, symmetric=True) From a76f7bb90c0193a0f416bb8799a62c5cc0841b9b Mon Sep 17 00:00:00 2001 From: vllmellm Date: Tue, 4 Nov 2025 12:13:02 +0000 Subject: [PATCH 19/36] rename flash_infer.py to flashinfer.py Signed-off-by: vllmellm --- .../kernels/scaled_mm/{flash_infer.py => flashinfer.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename vllm/model_executor/layers/quantization/kernels/scaled_mm/{flash_infer.py => flashinfer.py} (100%) diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/flash_infer.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/flashinfer.py similarity index 100% rename from vllm/model_executor/layers/quantization/kernels/scaled_mm/flash_infer.py rename to vllm/model_executor/layers/quantization/kernels/scaled_mm/flashinfer.py From f10171cb3d1b4e42f312ae1e079926ce1d508c42 Mon Sep 17 00:00:00 2001 From: vllmellm Date: Tue, 4 Nov 2025 12:22:49 +0000 Subject: [PATCH 20/36] 
correct minimum capability req for channelwise torch Signed-off-by: vllmellm --- .../layers/quantization/kernels/scaled_mm/pytorch.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/pytorch.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/pytorch.py index 7c4c64215a8e..40e55cc97392 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/pytorch.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/pytorch.py @@ -230,9 +230,6 @@ def apply_weights( class ChannelWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel): - @classmethod - def get_min_capability(cls) -> int: - return 94 @classmethod def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: From fb72ec8218e946027d21c81ef96392ce0bd1a2bd Mon Sep 17 00:00:00 2001 From: vllmellm Date: Tue, 4 Nov 2025 12:23:38 +0000 Subject: [PATCH 21/36] add missing kernels for cuda dispatch Signed-off-by: vllmellm --- .../quantization/kernels/scaled_mm/__init__.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py index b36b77109e92..67d077289578 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py @@ -17,6 +17,10 @@ CutlassFP8ScaledMMLinearKernel, CutlassScaledMMLinearKernel, ) + +from vllm.model_executor.layers.quantization.kernels.scaled_mm.flashinfer import ( + FlashInferScaledMMLinearKernel +) from vllm.model_executor.layers.quantization.kernels.scaled_mm.pytorch import ( ChannelWiseTorchScaledMMLinearKernel, PerTensorTorchScaledMMLinearKernel, @@ -54,7 +58,13 @@ # in priority/performance order (when available) _POSSIBLE_FP8_KERNELS: dict[PlatformEnum, list[type[FP8ScaledMMLinearKernel]]] = { - PlatformEnum.CUDA: 
[CutlassFP8ScaledMMLinearKernel], + PlatformEnum.CUDA: [ + FlashInferScaledMMLinearKernel, + CutlassFP8ScaledMMLinearKernel, + PerTensorTorchScaledMMLinearKernel, + RowWiseTorchScaledMMLinearKernel, + ChannelWiseTorchScaledMMLinearKernel, + ], PlatformEnum.ROCM: [ ROCmScaledMMLinearKernel, PerTensorTorchScaledMMLinearKernel, From 93fb7071f5ecbdde0c8c03a68f3b0fc692d5e8f0 Mon Sep 17 00:00:00 2001 From: vllmellm Date: Tue, 4 Nov 2025 13:10:29 +0000 Subject: [PATCH 22/36] reduce test boilerplate Signed-off-by: vllmellm --- tests/compile/test_functionalization.py | 27 ++++---------- tests/compile/test_fusion.py | 30 ++++++---------- tests/compile/test_fusion_all_reduce.py | 36 ++++++++----------- tests/compile/test_fusion_attn.py | 15 +++----- tests/compile/test_sequence_parallelism.py | 16 +++------ tests/compile/test_silu_mul_quant_fusion.py | 21 +++++------ tests/utils.py | 40 +++++++++++++++++++-- 7 files changed, 87 insertions(+), 98 deletions(-) diff --git a/tests/compile/test_functionalization.py b/tests/compile/test_functionalization.py index a40f8beccdc2..a10645227383 100644 --- a/tests/compile/test_functionalization.py +++ b/tests/compile/test_functionalization.py @@ -20,14 +20,12 @@ ) from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( - init_fp8_linear_kernel, -) from vllm.model_executor.layers.quantization.utils.quant_utils import ( kFp8StaticTensorSym, ) from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.platforms import current_platform +from ..utils import TestFP8Layer from .backend import TestBackend @@ -43,20 +41,14 @@ def __init__(self, hidden_size: int = 128): self.silu_and_mul = SiluAndMul() self.weight_scale = torch.rand(1, dtype=torch.float32) self.input_scale = torch.rand(1, dtype=torch.float32) - self.input_scale_ub = None if TEST_FP8: self.weight = torch.rand(hidden_size, 
hidden_size).to(dtype=FP8_DTYPE).t() - self.fp8_linear = init_fp8_linear_kernel( - activation_quant_key=self.quant_key, - weight_quant_key=self.quant_key, - out_dtype=torch.get_default_dtype(), - module_name=self.__class__.__name__, - ) - + self.fp8_linear = TestFP8Layer(self.quant_key, self.quant_key, self.weight, + self.weight_scale, self.input_scale) def forward(self, x): y = self.silu_and_mul(x) if TEST_FP8: - return self.fp8_linear.apply_weights(self, y) + return self.fp8_linear(y) else: return y @@ -90,18 +82,13 @@ def __init__(self, hidden_size=16, intermediate_size=32): torch.nn.init.normal_(self.gate_proj, std=0.02) if TEST_FP8: - self.fp8_linear = init_fp8_linear_kernel( - activation_quant_key=self.quant_key, - weight_quant_key=self.quant_key, - out_dtype=torch.get_default_dtype(), - module_name=self.__class__.__name__, - ) self.weight = ( torch.rand(hidden_size, intermediate_size).to(dtype=FP8_DTYPE).t() ) self.weight_scale = torch.rand(1, dtype=torch.float32) self.input_scale = torch.rand(1, dtype=torch.float32) - self.input_scale_ub = None + self.fp8_linear = TestFP8Layer(self.quant_key, self.quant_key, + self.weight, self.weight_scale, self.input_scale) def forward(self, hidden_states, residual): # Reshape input @@ -117,7 +104,7 @@ def forward(self, hidden_states, residual): if TEST_FP8: self.input_scale = self.input_scale.to(norm_output.device) # scaled_mm with static input quantization - fp8_linear_result = self.fp8_linear.apply_weights(self, norm_output) + fp8_linear_result = self.fp8_linear(norm_output) return fp8_linear_result, residual_output diff --git a/tests/compile/test_fusion.py b/tests/compile/test_fusion.py index 6270344c2eb3..e627c67288cf 100644 --- a/tests/compile/test_fusion.py +++ b/tests/compile/test_fusion.py @@ -18,9 +18,7 @@ VllmConfig, ) from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( - init_fp8_linear_kernel, -) + from 
vllm.model_executor.layers.quantization.utils.quant_utils import ( GroupShape, QuantKey, @@ -76,36 +74,30 @@ def __init__( ] with override_cutlass_fp8_supported(not cuda_force_torch): - self.fp8_linear = init_fp8_linear_kernel( - activation_quant_key=self.activation_quant_key, - weight_quant_key=self.weight_quant_key, - out_dtype=torch.get_default_dtype(), - module_name=self.__class__.__name__, - ) + self.fp8_linear_1 = TestFP8Layer(self.activation_quant_key, self.weight_quant_key, + self.w[0], self.wscale[0], self.scale[0]) + self.fp8_linear_2 = TestFP8Layer(self.activation_quant_key, self.weight_quant_key, + self.w[1], self.wscale[1], self.scale[1]) + self.fp8_linear_3 = TestFP8Layer(self.activation_quant_key, self.weight_quant_key, + self.w[2], self.wscale[2], self.scale[2]) self.enable_rms_norm_custom_op = self.norm[0].enabled() - self.enable_quant_fp8_custom_op = self.fp8_linear.quant_fp8.enabled() + self.enable_quant_fp8_custom_op = self.fp8_linear.is_quant_fp8_enabled() def forward(self, x): # avoid having graph input be an arg to a pattern directly x = resid = torch.relu(x) y = self.norm[0](x) - layer1 = TestFP8Layer(self.w[0], self.wscale[0], input_scale=self.scale[0]) - x2 = self.fp8_linear.apply_weights(layer1, y) + x2 = self.fp8_linear_1(y) # make sure resid is used for replacement to work y2, resid = self.norm[1](x2, resid) - layer2 = TestFP8Layer(self.w[1], self.wscale[1], input_scale=self.scale[1]) - x3 = self.fp8_linear.apply_weights(layer2, y2) + x3 = self.fp8_linear_2(y2) y3, resid = self.norm[2](x3, resid) # use resid here - layer3 = TestFP8Layer(self.w[2], self.wscale[2], input_scale=self.scale[2]) - x4 = self.fp8_linear.apply_weights( - layer3, - y3, - ) + x4 = self.fp8_linear_3(y3) y4, resid = self.norm[3](x4, resid) # use resid here return y4 diff --git a/tests/compile/test_fusion_all_reduce.py b/tests/compile/test_fusion_all_reduce.py index 5e2c46f8ea91..161d703b79f1 100644 --- a/tests/compile/test_fusion_all_reduce.py +++ 
b/tests/compile/test_fusion_all_reduce.py @@ -26,9 +26,7 @@ initialize_model_parallel, ) from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( - init_fp8_linear_kernel, -) + from vllm.model_executor.layers.quantization.utils.quant_utils import ( kFp8StaticTensorSym, ) @@ -93,12 +91,14 @@ def __init__(self, hidden_size=16, token_num=16, eps=1e-6): for _ in range(3) ] - self.fp8_linear = init_fp8_linear_kernel( - activation_quant_key=self.quant_key, - weight_quant_key=self.quant_key, - out_dtype=torch.get_default_dtype(), - module_name=self.__class__.__name__, - ) + self.fp8_linear_1 = TestFP8Layer(self.quant_key,self.quant_key, + self.weight[0],self.wscale[0], input_scale=self.input_scale[0]) + + self.fp8_linear_2 = TestFP8Layer(self.quant_key,self.quant_key, + self.weight[1],self.wscale[1], input_scale=self.input_scale[1]) + + self.fp8_linear_3 = TestFP8Layer(self.quant_key, self.quant_key, + self.weight[2], self.wscale[2],input_scale=self.input_scale[2]) def forward(self, hidden_states): # avoid having graph input be an arg to a pattern directly @@ -106,26 +106,18 @@ def forward(self, hidden_states): x = resid = tensor_model_parallel_all_reduce(z) y = self.norm[0](x) - layer1 = TestFP8Layer( - self.weight[0], self.weight_scale[0], input_scale=self.input_scale[0] - ) - z2 = self.fp8_linear.apply_weights(layer1, y) + + z2 = self.fp8_linear_1(y) x2 = tensor_model_parallel_all_reduce(z2) y2, resid = self.norm[1](x2, resid) - layer2 = TestFP8Layer( - self.weight[1], self.weight_scale[1], input_scale=self.input_scale[1] - ) - z3 = self.fp8_linear.apply(layer2, y2) + z3 = self.fp8_linear_2(y2) x3 = tensor_model_parallel_all_reduce(z3) y3, resid = self.norm[2](x3, resid) # use resid here - layer3 = TestFP8Layer( - self.weight[2], self.weight_scale[2], input_scale=self.input_scale[2] - ) - z4 = self.fp8_linear.apply(layer3, y3) + z4 = self.fp8_linear_3(y3) x4 = tensor_model_parallel_all_reduce(z4) 
y4, resid = self.norm[3](x4, resid) # use resid here @@ -138,7 +130,7 @@ def ops_in_model_before(self): return [ torch.ops.vllm.all_reduce.default, torch.ops._C.static_scaled_fp8_quant.default - if self.fp8_linear.quant_fp8.enabled() + if self.fp8_linear.is_quant_fp8_enabled() else torch.ops.aten.reciprocal.default, ] diff --git a/tests/compile/test_fusion_attn.py b/tests/compile/test_fusion_attn.py index 9068e304f551..1762af27d190 100644 --- a/tests/compile/test_fusion_attn.py +++ b/tests/compile/test_fusion_attn.py @@ -28,9 +28,7 @@ set_current_vllm_config, ) from vllm.forward_context import get_forward_context, set_forward_context -from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( - init_fp8_linear_kernel, -) + from vllm.model_executor.layers.quantization.utils.quant_utils import ( QuantKey, kFp8StaticTensorSym, @@ -174,12 +172,6 @@ class TestAttentionFp8StaticQuantPatternModel(AttentionQuantPatternModel): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.fp8_linear = init_fp8_linear_kernel( - activation_quant_key=self.quant_key, - weight_quant_key=self.quant_key, - out_dtype=torch.get_default_dtype(), - module_name=self.__class__.__name__, - ) hidden_size = self.num_qo_heads * self.head_size self.w = kwargs.get( @@ -192,12 +184,13 @@ def __init__(self, *args, **kwargs): "scale": torch.tensor([1.0], dtype=torch.float32, device=self.device), }, ) + self.fp8_linear = TestFP8Layer(self.quant_key, self.quant_key, self.w["weight"], + self.w["wscale"], self.w["scale"]) def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor): """Forward pass that creates the pattern to be fused.""" attn_output = self.attn(q, k, v) - layer = TestFP8Layer(self.w["weight"], self.w["wscale"], self.w["scale"]) - return self.fp8_linear.apply_weights(layer, attn_output) + return self.fp8_linear(attn_output) class TestAttentionNvfp4QuantPatternModel(AttentionQuantPatternModel): diff --git 
a/tests/compile/test_sequence_parallelism.py b/tests/compile/test_sequence_parallelism.py index f579815338a9..0e422f4ee132 100644 --- a/tests/compile/test_sequence_parallelism.py +++ b/tests/compile/test_sequence_parallelism.py @@ -27,9 +27,7 @@ initialize_model_parallel, ) from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( - init_fp8_linear_kernel, -) + from vllm.model_executor.layers.quantization.utils.quant_utils import ( kFp8StaticTensorSym, ) @@ -114,18 +112,15 @@ def __init__(self, hidden_size=16, intermediate_size=32): # Initialize weights torch.nn.init.normal_(self.gate_proj, std=0.02) - self.fp8_linear = init_fp8_linear_kernel( - activation_quant_key=self.quant_key, - weight_quant_key=self.quant_key, - out_dtype=torch.get_default_dtype(), - module_name=self.__class__.__name__, - ) self.scale = torch.rand(1, dtype=torch.float32) # Create a weight that is compatible with torch._scaled_mm, # which expects a column-major layout. 
self.w = torch.rand(hidden_size, intermediate_size).to(dtype=FP8_DTYPE).t() self.wscale = torch.rand(1, dtype=torch.float32) + self.fp8_linear = TestFP8Layer(self.quant_key, self.quant_key, + self.w, self.wscale, self.scale) + def forward(self, hidden_states, residual): """ Forward pass implementing the operations in the FX graph @@ -150,8 +145,7 @@ def forward(self, hidden_states, residual): # layer normalization norm_output, residual_output = self.norm(all_reduce, residual) # scaled_mm with static input quantization - layer = TestFP8Layer(None, None, self.scale.to(norm_output.device)) - fp8_linear_result = self.fp8_linear.apply(layer, norm_output) + fp8_linear_result = self.fp8_linear(norm_output) return fp8_linear_result, residual_output diff --git a/tests/compile/test_silu_mul_quant_fusion.py b/tests/compile/test_silu_mul_quant_fusion.py index 20e7c2955d01..6e6f54a7fbb2 100644 --- a/tests/compile/test_silu_mul_quant_fusion.py +++ b/tests/compile/test_silu_mul_quant_fusion.py @@ -24,9 +24,7 @@ set_current_vllm_config, ) from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( - init_fp8_linear_kernel, -) + from vllm.model_executor.layers.quantization.utils.quant_utils import ( kFp8StaticTensorSym, kNvfp4Quant, @@ -36,7 +34,7 @@ ) from vllm.platforms import current_platform -from ..utils import override_cutlass_fp8_supported +from ..utils import TestFP8Layer, override_cutlass_fp8_supported from .backend import TestBackend FP8_DTYPE = current_platform.fp8_dtype() @@ -55,22 +53,19 @@ def __init__(self, hidden_size: int, cuda_force_torch: bool, **kwargs): self.silu_and_mul = SiluAndMul() self.weight_scale = torch.rand(1, dtype=torch.float32) self.input_scale = torch.rand(1, dtype=torch.float32) - self.input_scale_ub = None self.weight = torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t() with override_cutlass_fp8_supported(not cuda_force_torch): - self.fp8_linear = 
init_fp8_linear_kernel( - activation_quant_key=self.quant_key, - weight_quant_key=self.quant_key, - out_dtype=torch.get_default_dtype(), - module_name=self.__class__.__name__, - ) + self.fp8_linear = TestFP8Layer(self.quant_key, self.quant_key, + self.weight, self.weight_scale, self.input_scale) + + self.enable_silu_mul_custom_op = self.silu_and_mul.enabled() - self.enable_quant_fp8_custom_op = self.fp8_linear.quant_fp8.enabled() + self.enable_quant_fp8_custom_op = self.fp8_linear.is_quant_fp8_enabled() def forward(self, x): y = self.silu_and_mul(x) - x2 = self.fp8_linear.apply_weights(self, y) + x2 = self.fp8_linear(y) return x2 def ops_in_model_before(self): diff --git a/tests/utils.py b/tests/utils.py index bb3bbc750350..5c2e10f47318 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -49,6 +49,8 @@ from vllm.utils.mem_constants import GB_bytes from vllm.utils.network_utils import get_open_port from vllm.utils.torch_utils import cuda_device_count_stateless +from vllm.model_executor.layers.quantization.kernels.scaled_mm import init_fp8_linear_kernel +from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey if current_platform.is_rocm(): from amdsmi import ( @@ -1414,11 +1416,45 @@ def flat_product(*iterables: Iterable[Any]): class TestFP8Layer(torch.nn.Module): - """Helper class for ScaledMMLinearKernels.""" + """ + Test helper class for evaluating FP8 linear operations with quantization. + + It supports configurable activation and weight quantization parameters, + and provides a forward method that applies the FP8 linear transformation + with optional bias. - def __init__(self, weight, weight_scale, input_scale): + Args: + activation_quant_key (QuantKey): Key for activation quantization configuration. + weight_quant_key (QuantKey): Key for weight quantization configuration. + weight (torch.Tensor): Weight tensor for linear transformation. + weight_scale (torch.Tensor): Per-tensor or per-group scale for weights. 
+ input_scale (torch.Tensor): Scale tensor for input quantization. + out_dtype (torch.dtype, optional): Output tensor data type. Defaults to torch.get_default_dtype(). + """ + def __init__(self, + activation_quant_key: QuantKey, + weight_quant_key: QuantKey, + weight:torch.Tensor, + weight_scale:torch.Tensor, + input_scale:torch.Tensor, + out_dtype: torch.dtype = torch.get_default_dtype() + ): super().__init__() self.weight_scale = weight_scale self.weight = weight self.input_scale = input_scale self.input_scale_ub = None + + self.kernel = init_fp8_linear_kernel( + activation_quant_key=activation_quant_key, + weight_quant_key=weight_quant_key, + out_dtype=out_dtype, + module_name=self.__class__.__name__, + ) + + def is_quant_fp8_enabled(self) -> bool: + return self.kernel.quant_fp8.enabled() + + def forward(self, y: torch.Tensor, bias: torch.Tensor | None=None) -> torch.Tensor: + return self.kernel.apply_weights(self, y, bias) + From abf597e542956f2ea3ebee66fca5f50c02d1620d Mon Sep 17 00:00:00 2001 From: vllmellm Date: Tue, 4 Nov 2025 14:12:14 +0000 Subject: [PATCH 23/36] fix quant key selection for ct; remove register_paramter calls; format Signed-off-by: vllmellm --- tests/compile/test_functionalization.py | 21 ++++++++--- tests/compile/test_fusion.py | 28 +++++++++++---- tests/compile/test_fusion_all_reduce.py | 33 +++++++++++------ tests/compile/test_fusion_attn.py | 11 +++--- tests/compile/test_sequence_parallelism.py | 7 ++-- tests/compile/test_silu_mul_quant_fusion.py | 11 +++--- tests/utils.py | 36 +++++++++++-------- .../schemes/compressed_tensors_w8a8_fp8.py | 16 ++++++--- .../model_executor/layers/quantization/fp8.py | 2 +- .../kernels/scaled_mm/__init__.py | 7 ++-- .../quark/schemes/quark_w8a8_fp8.py | 2 +- 11 files changed, 114 insertions(+), 60 deletions(-) diff --git a/tests/compile/test_functionalization.py b/tests/compile/test_functionalization.py index a10645227383..ef8ad92a923e 100644 --- a/tests/compile/test_functionalization.py +++ 
b/tests/compile/test_functionalization.py @@ -25,8 +25,8 @@ ) from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.platforms import current_platform -from ..utils import TestFP8Layer +from ..utils import TestFP8Layer from .backend import TestBackend TEST_FP8 = current_platform.supports_fp8() @@ -43,8 +43,14 @@ def __init__(self, hidden_size: int = 128): self.input_scale = torch.rand(1, dtype=torch.float32) if TEST_FP8: self.weight = torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t() - self.fp8_linear = TestFP8Layer(self.quant_key, self.quant_key, self.weight, - self.weight_scale, self.input_scale) + self.fp8_linear = TestFP8Layer( + self.quant_key, + self.quant_key, + self.weight, + self.weight_scale, + self.input_scale, + ) + def forward(self, x): y = self.silu_and_mul(x) if TEST_FP8: @@ -87,8 +93,13 @@ def __init__(self, hidden_size=16, intermediate_size=32): ) self.weight_scale = torch.rand(1, dtype=torch.float32) self.input_scale = torch.rand(1, dtype=torch.float32) - self.fp8_linear = TestFP8Layer(self.quant_key, self.quant_key, - self.weight, self.weight_scale, self.input_scale) + self.fp8_linear = TestFP8Layer( + self.quant_key, + self.quant_key, + self.weight, + self.weight_scale, + self.input_scale, + ) def forward(self, hidden_states, residual): # Reshape input diff --git a/tests/compile/test_fusion.py b/tests/compile/test_fusion.py index e627c67288cf..a8ac8eb576da 100644 --- a/tests/compile/test_fusion.py +++ b/tests/compile/test_fusion.py @@ -18,7 +18,6 @@ VllmConfig, ) from vllm.model_executor.layers.layernorm import RMSNorm - from vllm.model_executor.layers.quantization.utils.quant_utils import ( GroupShape, QuantKey, @@ -74,12 +73,27 @@ def __init__( ] with override_cutlass_fp8_supported(not cuda_force_torch): - self.fp8_linear_1 = TestFP8Layer(self.activation_quant_key, self.weight_quant_key, - self.w[0], self.wscale[0], self.scale[0]) - self.fp8_linear_2 = TestFP8Layer(self.activation_quant_key, 
self.weight_quant_key, - self.w[1], self.wscale[1], self.scale[1]) - self.fp8_linear_3 = TestFP8Layer(self.activation_quant_key, self.weight_quant_key, - self.w[2], self.wscale[2], self.scale[2]) + self.fp8_linear_1 = TestFP8Layer( + self.activation_quant_key, + self.weight_quant_key, + self.w[0], + self.wscale[0], + self.scale[0], + ) + self.fp8_linear_2 = TestFP8Layer( + self.activation_quant_key, + self.weight_quant_key, + self.w[1], + self.wscale[1], + self.scale[1], + ) + self.fp8_linear_3 = TestFP8Layer( + self.activation_quant_key, + self.weight_quant_key, + self.w[2], + self.wscale[2], + self.scale[2], + ) self.enable_rms_norm_custom_op = self.norm[0].enabled() self.enable_quant_fp8_custom_op = self.fp8_linear.is_quant_fp8_enabled() diff --git a/tests/compile/test_fusion_all_reduce.py b/tests/compile/test_fusion_all_reduce.py index 161d703b79f1..bda2620d3e2f 100644 --- a/tests/compile/test_fusion_all_reduce.py +++ b/tests/compile/test_fusion_all_reduce.py @@ -26,7 +26,6 @@ initialize_model_parallel, ) from vllm.model_executor.layers.layernorm import RMSNorm - from vllm.model_executor.layers.quantization.utils.quant_utils import ( kFp8StaticTensorSym, ) @@ -91,14 +90,29 @@ def __init__(self, hidden_size=16, token_num=16, eps=1e-6): for _ in range(3) ] - self.fp8_linear_1 = TestFP8Layer(self.quant_key,self.quant_key, - self.weight[0],self.wscale[0], input_scale=self.input_scale[0]) - - self.fp8_linear_2 = TestFP8Layer(self.quant_key,self.quant_key, - self.weight[1],self.wscale[1], input_scale=self.input_scale[1]) - - self.fp8_linear_3 = TestFP8Layer(self.quant_key, self.quant_key, - self.weight[2], self.wscale[2],input_scale=self.input_scale[2]) + self.fp8_linear_1 = TestFP8Layer( + self.quant_key, + self.quant_key, + self.weight[0], + self.wscale[0], + input_scale=self.input_scale[0], + ) + + self.fp8_linear_2 = TestFP8Layer( + self.quant_key, + self.quant_key, + self.weight[1], + self.wscale[1], + input_scale=self.input_scale[1], + ) + + self.fp8_linear_3 = 
TestFP8Layer( + self.quant_key, + self.quant_key, + self.weight[2], + self.wscale[2], + input_scale=self.input_scale[2], + ) def forward(self, hidden_states): # avoid having graph input be an arg to a pattern directly @@ -106,7 +120,6 @@ def forward(self, hidden_states): x = resid = tensor_model_parallel_all_reduce(z) y = self.norm[0](x) - z2 = self.fp8_linear_1(y) x2 = tensor_model_parallel_all_reduce(z2) diff --git a/tests/compile/test_fusion_attn.py b/tests/compile/test_fusion_attn.py index 1762af27d190..60e01a0b0b63 100644 --- a/tests/compile/test_fusion_attn.py +++ b/tests/compile/test_fusion_attn.py @@ -28,7 +28,6 @@ set_current_vllm_config, ) from vllm.forward_context import get_forward_context, set_forward_context - from vllm.model_executor.layers.quantization.utils.quant_utils import ( QuantKey, kFp8StaticTensorSym, @@ -172,7 +171,6 @@ class TestAttentionFp8StaticQuantPatternModel(AttentionQuantPatternModel): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - hidden_size = self.num_qo_heads * self.head_size self.w = kwargs.get( "w", @@ -184,8 +182,13 @@ def __init__(self, *args, **kwargs): "scale": torch.tensor([1.0], dtype=torch.float32, device=self.device), }, ) - self.fp8_linear = TestFP8Layer(self.quant_key, self.quant_key, self.w["weight"], - self.w["wscale"], self.w["scale"]) + self.fp8_linear = TestFP8Layer( + self.quant_key, + self.quant_key, + self.w["weight"], + self.w["wscale"], + self.w["scale"], + ) def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor): """Forward pass that creates the pattern to be fused.""" diff --git a/tests/compile/test_sequence_parallelism.py b/tests/compile/test_sequence_parallelism.py index 0e422f4ee132..fc4d38c8f837 100644 --- a/tests/compile/test_sequence_parallelism.py +++ b/tests/compile/test_sequence_parallelism.py @@ -27,7 +27,6 @@ initialize_model_parallel, ) from vllm.model_executor.layers.layernorm import RMSNorm - from 
vllm.model_executor.layers.quantization.utils.quant_utils import ( kFp8StaticTensorSym, ) @@ -117,10 +116,10 @@ def __init__(self, hidden_size=16, intermediate_size=32): # which expects a column-major layout. self.w = torch.rand(hidden_size, intermediate_size).to(dtype=FP8_DTYPE).t() self.wscale = torch.rand(1, dtype=torch.float32) - self.fp8_linear = TestFP8Layer(self.quant_key, self.quant_key, - self.w, self.wscale, self.scale) + self.fp8_linear = TestFP8Layer( + self.quant_key, self.quant_key, self.w, self.wscale, self.scale + ) - def forward(self, hidden_states, residual): """ Forward pass implementing the operations in the FX graph diff --git a/tests/compile/test_silu_mul_quant_fusion.py b/tests/compile/test_silu_mul_quant_fusion.py index 6e6f54a7fbb2..56b36856f7f2 100644 --- a/tests/compile/test_silu_mul_quant_fusion.py +++ b/tests/compile/test_silu_mul_quant_fusion.py @@ -24,7 +24,6 @@ set_current_vllm_config, ) from vllm.model_executor.layers.activation import SiluAndMul - from vllm.model_executor.layers.quantization.utils.quant_utils import ( kFp8StaticTensorSym, kNvfp4Quant, @@ -56,10 +55,14 @@ def __init__(self, hidden_size: int, cuda_force_torch: bool, **kwargs): self.weight = torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t() with override_cutlass_fp8_supported(not cuda_force_torch): - self.fp8_linear = TestFP8Layer(self.quant_key, self.quant_key, - self.weight, self.weight_scale, self.input_scale) + self.fp8_linear = TestFP8Layer( + self.quant_key, + self.quant_key, + self.weight, + self.weight_scale, + self.input_scale, + ) - self.enable_silu_mul_custom_op = self.silu_and_mul.enabled() self.enable_quant_fp8_custom_op = self.fp8_linear.is_quant_fp8_enabled() diff --git a/tests/utils.py b/tests/utils.py index 5c2e10f47318..ba28886e6079 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -42,6 +42,10 @@ ) from vllm.engine.arg_utils import AsyncEngineArgs from vllm.entrypoints.cli.serve import ServeSubcommand +from 
vllm.model_executor.layers.quantization.kernels.scaled_mm import ( + init_fp8_linear_kernel, +) +from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey from vllm.model_executor.model_loader import get_model_loader from vllm.platforms import current_platform from vllm.transformers_utils.tokenizer import get_tokenizer @@ -49,8 +53,6 @@ from vllm.utils.mem_constants import GB_bytes from vllm.utils.network_utils import get_open_port from vllm.utils.torch_utils import cuda_device_count_stateless -from vllm.model_executor.layers.quantization.kernels.scaled_mm import init_fp8_linear_kernel -from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey if current_platform.is_rocm(): from amdsmi import ( @@ -1429,32 +1431,36 @@ class TestFP8Layer(torch.nn.Module): weight (torch.Tensor): Weight tensor for linear transformation. weight_scale (torch.Tensor): Per-tensor or per-group scale for weights. input_scale (torch.Tensor): Scale tensor for input quantization. - out_dtype (torch.dtype, optional): Output tensor data type. Defaults to torch.get_default_dtype(). + out_dtype (torch.dtype, optional): Output tensor data type. + Defaults to torch.get_default_dtype(). 
""" - def __init__(self, - activation_quant_key: QuantKey, - weight_quant_key: QuantKey, - weight:torch.Tensor, - weight_scale:torch.Tensor, - input_scale:torch.Tensor, - out_dtype: torch.dtype = torch.get_default_dtype() - ): + + def __init__( + self, + activation_quant_key: QuantKey, + weight_quant_key: QuantKey, + weight: torch.Tensor, + weight_scale: torch.Tensor, + input_scale: torch.Tensor, + out_dtype: torch.dtype | None = None, + ): super().__init__() self.weight_scale = weight_scale self.weight = weight self.input_scale = input_scale self.input_scale_ub = None - + out_dtype = torch.get_default_dtype() if out_dtype is None else out_dtype self.kernel = init_fp8_linear_kernel( activation_quant_key=activation_quant_key, weight_quant_key=weight_quant_key, out_dtype=out_dtype, module_name=self.__class__.__name__, ) - + def is_quant_fp8_enabled(self) -> bool: return self.kernel.quant_fp8.enabled() - def forward(self, y: torch.Tensor, bias: torch.Tensor | None=None) -> torch.Tensor: + def forward( + self, y: torch.Tensor, bias: torch.Tensor | None = None + ) -> torch.Tensor: return self.kernel.apply_weights(self, y, bias) - diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py index 0d14c13180ab..a1c60fadce6d 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py @@ -30,6 +30,7 @@ GroupShape, kFp8DynamicTokenSym, kFp8StaticTensorSym, + kFp8StaticTokenSym, ) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( cutlass_block_fp8_supported, @@ -51,9 +52,13 @@ STATIC_QUANT = True DYNAMIC_QUANT = False -quant_keys = { - STATIC_QUANT: (kFp8StaticTensorSym, kFp8StaticTensorSym), - DYNAMIC_QUANT: (kFp8DynamicTokenSym, kFp8StaticTensorSym), 
+activation_quant_key_mapping = { + STATIC_QUANT: kFp8StaticTensorSym, + DYNAMIC_QUANT: kFp8DynamicTokenSym, +} +weight_quant_key_mapping = { + QuantizationStrategy.CHANNEL: kFp8StaticTokenSym, + QuantizationStrategy.TENSOR: kFp8StaticTensorSym, } logger = init_logger(__name__) @@ -78,7 +83,8 @@ def __init__(self, weight_quant: QuantizationArgs, is_static_input_scheme: bool) use_aiter_and_is_supported=self.use_aiter_and_is_supported, ) else: - activation_quant_key, weight_quant_key = quant_keys[is_static_input_scheme] + activation_quant_key = activation_quant_key_mapping[is_static_input_scheme] + weight_quant_key = weight_quant_key_mapping[self.strategy] self.fp8_linear = init_fp8_linear_kernel( activation_quant_key=activation_quant_key, weight_quant_key=weight_quant_key, @@ -143,7 +149,7 @@ def create_weights( input_scale = create_fp8_input_scale(output_partition_sizes, weight_loader) layer.register_parameter("input_scale", input_scale) - layer.register_parameter("input_scale_ub", None) + layer.input_scale_ub = None def process_weights_after_loading(self, layer) -> None: if self.strategy == QuantizationStrategy.TENSOR: diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index c04bcef7bb0b..6b613b21066a 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -451,7 +451,7 @@ def create_weights( weight_loader=weight_loader, ) layer.register_parameter("weight", weight) - layer.register_parameter("input_scale_ub", None) + layer.input_scale_ub = None # If checkpoint is serialized fp8, load them. # Otherwise, wait until process_weights_after_loading. 
diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py index 67d077289578..4a3f74f59126 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py @@ -17,9 +17,8 @@ CutlassFP8ScaledMMLinearKernel, CutlassScaledMMLinearKernel, ) - from vllm.model_executor.layers.quantization.kernels.scaled_mm.flashinfer import ( - FlashInferScaledMMLinearKernel + FlashInferScaledMMLinearKernel, ) from vllm.model_executor.layers.quantization.kernels.scaled_mm.pytorch import ( ChannelWiseTorchScaledMMLinearKernel, @@ -64,7 +63,7 @@ PerTensorTorchScaledMMLinearKernel, RowWiseTorchScaledMMLinearKernel, ChannelWiseTorchScaledMMLinearKernel, - ], + ], PlatformEnum.ROCM: [ ROCmScaledMMLinearKernel, PerTensorTorchScaledMMLinearKernel, @@ -164,7 +163,7 @@ def init_fp8_linear_kernel( logger.info_once( "Selected %s for %s", - kernel_type.__class__.__name__, + kernel_type.__name__, module_name, scope="global", ) diff --git a/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py b/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py index 343539c10fa8..819348c5b938 100644 --- a/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py +++ b/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py @@ -174,7 +174,7 @@ def create_weights( input_scale[:] = torch.finfo(torch.float32).min layer.register_parameter("input_scale", input_scale) - layer.register_parameter("input_scale_ub", None) + layer.input_scale_ub = None self.fp8_linear = init_fp8_linear_kernel( activation_quant_key=self.activation_quant_key, From aaa0d5558707e73c4df98deac2491756725c9699 Mon Sep 17 00:00:00 2001 From: vllmellm Date: Tue, 4 Nov 2025 14:40:30 +0000 Subject: [PATCH 24/36] format Signed-off-by: vllmellm --- .../layers/quantization/kernels/scaled_mm/pytorch.py | 
2 -- 1 file changed, 2 deletions(-) diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/pytorch.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/pytorch.py index 40e55cc97392..10293c445a34 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/pytorch.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/pytorch.py @@ -230,14 +230,12 @@ def apply_weights( class ChannelWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel): - @classmethod def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: per_tensor_activation_scales = ( c.activation_quant_key.scale.group_shape.is_per_tensor() ) per_tensor_weight_scales = c.weight_quant_key.scale.group_shape.is_per_tensor() - if per_tensor_activation_scales and per_tensor_weight_scales: return ( False, From 7fb465744c12b0d10372cbb1e514606aebbbcc88 Mon Sep 17 00:00:00 2001 From: vllmellm Date: Fri, 7 Nov 2025 07:17:41 +0000 Subject: [PATCH 25/36] implement apply func in base FP8ScaledMMLinearKernel class Signed-off-by: vllmellm --- .../kernels/scaled_mm/ScaledMMLinearKernel.py | 55 ++++++++++++++-- .../quantization/kernels/scaled_mm/cutlass.py | 29 +++------ .../kernels/scaled_mm/flashinfer.py | 37 ++++------- .../quantization/kernels/scaled_mm/pytorch.py | 63 +++---------------- .../quantization/kernels/scaled_mm/rocm.py | 29 +++------ .../quantization/kernels/scaled_mm/utils.py | 49 --------------- 6 files changed, 83 insertions(+), 179 deletions(-) delete mode 100644 vllm/model_executor/layers/quantization/kernels/scaled_mm/utils.py diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py index a8a2fc245f62..5baa7f73077a 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py @@ -2,7 +2,7 @@ 
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project from abc import ABC, abstractmethod -from collections.abc import Sequence +from collections.abc import Callable, Sequence from dataclasses import dataclass from typing import Generic, TypeVar @@ -12,6 +12,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( QuantKey, ) +from vllm.platforms import current_platform @dataclass @@ -98,12 +99,9 @@ def __init__( group_shape=act_scale_descriptor.group_shape, num_token_padding=self.get_ouput_padding(), ) + self.fp8_dtype = current_platform.fp8_dtype() super().__init__(c, layer_param_names) - @abstractmethod - def get_ouput_padding(self) -> int | None: - raise NotImplementedError - @classmethod def get_min_capability(cls) -> int: # lovelace and up @@ -121,6 +119,53 @@ def _get_layer_params(self, layer) -> _FP8ParamsT: getattr(layer, x_s_ub), ) + def apply_weights( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: torch.Tensor | None = None, + ) -> torch.Tensor: + scaled_mm_func = self.get_scaled_mm_func() + quant_fp8 = self.quant_fp8 + fp8_dtype = self.fp8_dtype + maybe_out_dtype = self.config.out_dtype + w, w_s, x_s, x_s_ub = self._get_layer_params(layer) + + # ops.scaled_fp8_quant supports both dynamic and static quant. + # If dynamic, layer.input_scale is None and x_s computed from x. + # If static, layer.input_scale is scalar and x_s is input_scale. 
+ # View input as 2D matrix for fp8 methods + x_2d = x.view(-1, x.shape[-1]) + output_shape = [*x.shape[:-1], w.shape[1]] + out_dtype = x.dtype if maybe_out_dtype is None else maybe_out_dtype + + # If input not quantized + # TODO(luka) remove this path if not used anymore + x_2d_q = x_2d + if x.dtype != fp8_dtype: + x_2d_q, x_s = quant_fp8( + x_2d, + x_s, + x_s_ub, + ) + return scaled_mm_func( + A=x_2d_q, + B=w, + out_dtype=out_dtype, + As=x_s, + Bs=w_s, + bias=bias, + output_shape=output_shape, + ) + + @abstractmethod + def get_scaled_mm_func(self) -> Callable[..., torch.Tensor]: + raise NotImplementedError + + @abstractmethod + def get_ouput_padding(self) -> int | None: + raise NotImplementedError + class Int8ScaledMMLinearKernel( ScaledMMLinearKernel[Int8ScaledMMLinearLayerConfig, _Int8ParamsT], ABC diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py index fc8893cb7e1b..dbed97078556 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py @@ -2,6 +2,8 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections.abc import Callable + import torch from vllm import _custom_ops as ops @@ -17,7 +19,6 @@ Int8ScaledMMLinearKernel, Int8ScaledMMLinearLayerConfig, ) -from .utils import apply_weights_fp8 def cutlass_w8a8_scaled_mm_fp8( @@ -160,9 +161,6 @@ def apply_weights( class CutlassFP8ScaledMMLinearKernel(FP8ScaledMMLinearKernel): - def get_ouput_padding(self) -> int | None: - return None - @classmethod def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: if not current_platform.is_cuda(): @@ -174,21 +172,8 @@ def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | Non return True, None - def apply_weights( - self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: torch.Tensor | None = 
None, - ): - w, w_s, x_s, x_s_ub = self._get_layer_params(layer) - return apply_weights_fp8( - cutlass_w8a8_scaled_mm_fp8, - self.quant_fp8, - w, - x, - w_s, - x_s, - bias, - x_s_ub, - self.config.out_dtype, - ) + def get_scaled_mm_func(self) -> Callable[..., torch.Tensor]: + return cutlass_w8a8_scaled_mm_fp8 + + def get_ouput_padding(self) -> int | None: + return None diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/flashinfer.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/flashinfer.py index 3bac71950dda..e816f5d2c156 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/flashinfer.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/flashinfer.py @@ -1,6 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections.abc import Callable + import torch from vllm.platforms import current_platform @@ -10,7 +12,6 @@ FP8ScaledMMLinearKernel, FP8ScaledMMLinearLayerConfig, ) -from .utils import apply_weights_fp8 def flashinfer_w8a8_scaled_mm( @@ -29,13 +30,6 @@ def flashinfer_w8a8_scaled_mm( class FlashInferScaledMMLinearKernel(FP8ScaledMMLinearKernel): - def get_ouput_padding(self) -> int | None: - return None - - @classmethod - def get_min_capability(cls) -> int: - return 100 - @classmethod def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: per_tensor_activation_scales = ( @@ -71,21 +65,12 @@ def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | Non ) return True, None - def apply_weights( - self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: torch.Tensor | None = None, - ): - w, w_s, x_s, x_s_ub = self._get_layer_params(layer) - return apply_weights_fp8( - flashinfer_w8a8_scaled_mm, - self.quant_fp8, - w, - x, - w_s, - x_s, - bias, - x_s_ub, - self.config.out_dtype, - ) + @classmethod + def get_min_capability(cls) -> int: + return 100 + + def get_scaled_mm_func(self) 
-> Callable[..., torch.Tensor]: + return flashinfer_w8a8_scaled_mm + + def get_ouput_padding(self) -> int | None: + return None diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/pytorch.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/pytorch.py index 10293c445a34..b7aed6105d10 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/pytorch.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/pytorch.py @@ -1,6 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections.abc import Callable + import torch from packaging import version @@ -11,7 +13,6 @@ FP8ScaledMMLinearKernel, FP8ScaledMMLinearLayerConfig, ) -from .utils import apply_weights_fp8 # Input scaling factors are no longer optional in _scaled_mm starting # from pytorch 2.5. Allocating a dummy tensor to pass as input_scale @@ -155,24 +156,8 @@ def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | Non ) return True, None - def apply_weights( - self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: torch.Tensor | None = None, - ): - w, w_s, x_s, x_s_ub = self._get_layer_params(layer) - return apply_weights_fp8( - torch_per_tensor_w8a8_scaled_mm, - self.quant_fp8, - w, - x, - w_s, - x_s, - bias, - x_s_ub, - self.config.out_dtype, - ) + def get_scaled_mm_func(self) -> Callable[..., torch.Tensor]: + return torch_per_tensor_w8a8_scaled_mm class RowWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel): @@ -209,24 +194,8 @@ def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | Non return True, None - def apply_weights( - self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: torch.Tensor | None = None, - ): - w, w_s, x_s, x_s_ub = self._get_layer_params(layer) - return apply_weights_fp8( - torch_row_wise_w8a8_scaled_mm, - self.quant_fp8, - w, - x, - w_s, - x_s, - bias, - x_s_ub, - self.config.out_dtype, - ) + def 
get_scaled_mm_func(self) -> Callable[..., torch.Tensor]: + return torch_row_wise_w8a8_scaled_mm class ChannelWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel): @@ -245,21 +214,5 @@ def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | Non return True, None - def apply_weights( - self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: torch.Tensor | None = None, - ): - w, w_s, x_s, x_s_ub = self._get_layer_params(layer) - return apply_weights_fp8( - torch_channelwise_w8a8_scaled_mm, - self.quant_fp8, - w, - x, - w_s, - x_s, - bias, - x_s_ub, - self.config.out_dtype, - ) + def get_scaled_mm_func(self) -> Callable[..., torch.Tensor]: + return torch_channelwise_w8a8_scaled_mm diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/rocm.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/rocm.py index 26463a19c6f4..852e0088d0d9 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/rocm.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/rocm.py @@ -1,6 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections.abc import Callable + import torch import vllm.envs as envs @@ -12,7 +14,6 @@ FP8ScaledMMLinearKernel, FP8ScaledMMLinearLayerConfig, ) -from .utils import apply_weights_fp8 def rocm_per_tensor_float_w8a8_scaled_mm_impl( @@ -86,9 +87,6 @@ def rocm_per_tensor_float_w8a8_scaled_mm( class ROCmScaledMMLinearKernel(FP8ScaledMMLinearKernel): - def get_ouput_padding(self) -> int | None: - return None - @classmethod def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: # TODO: check if this causes an issue on non-ROCM platforms @@ -125,21 +123,8 @@ def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | Non ) return True, None - def apply_weights( - self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: torch.Tensor | None = None, - ): - w, w_s, x_s, x_s_ub 
= self._get_layer_params(layer) - return apply_weights_fp8( - rocm_per_tensor_float_w8a8_scaled_mm, - self.quant_fp8, - w, - x, - w_s, - x_s, - bias, - x_s_ub, - self.config.out_dtype, - ) + def get_scaled_mm_func(self) -> Callable[..., torch.Tensor]: + return rocm_per_tensor_float_w8a8_scaled_mm + + def get_ouput_padding(self) -> int | None: + return None diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/utils.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/utils.py deleted file mode 100644 index e5ab5ad4d47c..000000000000 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/utils.py +++ /dev/null @@ -1,49 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from collections.abc import Callable - -import torch - -from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 -from vllm.platforms import current_platform - - -def apply_weights_fp8( - scaled_mm_func: Callable[..., torch.Tensor], - quant_fp8_func: QuantFP8, - w: torch.Tensor, - x: torch.Tensor, - w_s: torch.Tensor, - x_s: torch.Tensor, - bias: torch.Tensor, - x_s_ub: torch.Tensor | None, - maybe_out_dtype: torch.dtype | None = None, -) -> torch.Tensor: - # ops.scaled_fp8_quant supports both dynamic and static quant. - # If dynamic, layer.input_scale is None and x_s computed from x. - # If static, layer.input_scale is scalar and x_s is input_scale. 
- # View input as 2D matrix for fp8 methods - x_2d = x.view(-1, x.shape[-1]) - output_shape = [*x.shape[:-1], w.shape[1]] - - out_dtype = x.dtype if maybe_out_dtype is None else maybe_out_dtype - - # If input not quantized - # TODO(luka) remove this path if not used anymore - x_2d_q = x_2d - if x.dtype != current_platform.fp8_dtype(): - x_2d_q, x_s = quant_fp8_func( - x_2d, - x_s, - x_s_ub, - ) - - return scaled_mm_func( - A=x_2d_q, - B=w, - out_dtype=out_dtype, - As=x_s, - Bs=w_s, - bias=bias, - output_shape=output_shape, - ) From 56a05cd818d28d24a1e992972b74d27b7621d644 Mon Sep 17 00:00:00 2001 From: vllmellm Date: Fri, 7 Nov 2025 07:30:57 +0000 Subject: [PATCH 26/36] add minimal documentation for torch scaled mm base class Signed-off-by: vllmellm --- .../layers/quantization/kernels/scaled_mm/pytorch.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/pytorch.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/pytorch.py index b7aed6105d10..8c0f0e1d57fb 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/pytorch.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/pytorch.py @@ -133,6 +133,14 @@ def torch_channelwise_w8a8_scaled_mm( class TorchScaledMMLinearKernel(FP8ScaledMMLinearKernel): + """ + Base class for FP8 linear kernels using Torch. + Each subclass represents a kernel variant for + specific device capabilities and torch versions, + so we split them up and implement + get_min_capability() separately for each. 
+ """ + def get_ouput_padding(self) -> int | None: vllm_config = get_current_vllm_config().compilation_config pad_output = vllm_config.mode < CompilationMode.VLLM_COMPILE From 9ff9b44e0dfcf1f6721632798e292417c12b0f94 Mon Sep 17 00:00:00 2001 From: vllmellm Date: Fri, 7 Nov 2025 07:39:08 +0000 Subject: [PATCH 27/36] use for loops for fp8 linear layers init in tests Signed-off-by: vllmellm --- tests/compile/test_fusion.py | 41 ++++++++++--------------- tests/compile/test_fusion_all_reduce.py | 41 +++++++++---------------- 2 files changed, 30 insertions(+), 52 deletions(-) diff --git a/tests/compile/test_fusion.py b/tests/compile/test_fusion.py index a8ac8eb576da..6d27d10f687a 100644 --- a/tests/compile/test_fusion.py +++ b/tests/compile/test_fusion.py @@ -73,45 +73,36 @@ def __init__( ] with override_cutlass_fp8_supported(not cuda_force_torch): - self.fp8_linear_1 = TestFP8Layer( - self.activation_quant_key, - self.weight_quant_key, - self.w[0], - self.wscale[0], - self.scale[0], - ) - self.fp8_linear_2 = TestFP8Layer( - self.activation_quant_key, - self.weight_quant_key, - self.w[1], - self.wscale[1], - self.scale[1], - ) - self.fp8_linear_3 = TestFP8Layer( - self.activation_quant_key, - self.weight_quant_key, - self.w[2], - self.wscale[2], - self.scale[2], - ) + self.fp8_linear_layers = [ + TestFP8Layer( + self.activation_quant_key, + self.weight_quant_key, + self.w[i], + self.wscale[i], + input_scale=self.scale[i], + ) + for i in range(3) + ] self.enable_rms_norm_custom_op = self.norm[0].enabled() - self.enable_quant_fp8_custom_op = self.fp8_linear.is_quant_fp8_enabled() + self.enable_quant_fp8_custom_op = self.fp8_linear_layers[ + 0 + ].is_quant_fp8_enabled() def forward(self, x): # avoid having graph input be an arg to a pattern directly x = resid = torch.relu(x) y = self.norm[0](x) - x2 = self.fp8_linear_1(y) + x2 = self.fp8_linear_layers[0](y) # make sure resid is used for replacement to work y2, resid = self.norm[1](x2, resid) - x3 = self.fp8_linear_2(y2) + 
x3 = self.fp8_linear_layers[1](y2) y3, resid = self.norm[2](x3, resid) # use resid here - x4 = self.fp8_linear_3(y3) + x4 = self.fp8_linear_layers[2](y3) y4, resid = self.norm[3](x4, resid) # use resid here return y4 diff --git a/tests/compile/test_fusion_all_reduce.py b/tests/compile/test_fusion_all_reduce.py index bda2620d3e2f..a539f4a16038 100644 --- a/tests/compile/test_fusion_all_reduce.py +++ b/tests/compile/test_fusion_all_reduce.py @@ -90,29 +90,16 @@ def __init__(self, hidden_size=16, token_num=16, eps=1e-6): for _ in range(3) ] - self.fp8_linear_1 = TestFP8Layer( - self.quant_key, - self.quant_key, - self.weight[0], - self.wscale[0], - input_scale=self.input_scale[0], - ) - - self.fp8_linear_2 = TestFP8Layer( - self.quant_key, - self.quant_key, - self.weight[1], - self.wscale[1], - input_scale=self.input_scale[1], - ) - - self.fp8_linear_3 = TestFP8Layer( - self.quant_key, - self.quant_key, - self.weight[2], - self.wscale[2], - input_scale=self.input_scale[2], - ) + self.fp8_linear_layers = [ + TestFP8Layer( + self.quant_key, + self.quant_key, + self.weight[i], + self.wscale[i], + input_scale=self.input_scale[i], + ) + for i in range(3) + ] def forward(self, hidden_states): # avoid having graph input be an arg to a pattern directly @@ -120,17 +107,17 @@ def forward(self, hidden_states): x = resid = tensor_model_parallel_all_reduce(z) y = self.norm[0](x) - z2 = self.fp8_linear_1(y) + z2 = self.fp8_linear_layers[0](y) x2 = tensor_model_parallel_all_reduce(z2) y2, resid = self.norm[1](x2, resid) - z3 = self.fp8_linear_2(y2) + z3 = self.fp8_linear_layers[1](y2) x3 = tensor_model_parallel_all_reduce(z3) y3, resid = self.norm[2](x3, resid) # use resid here - z4 = self.fp8_linear_3(y3) + z4 = self.fp8_linear_layers[2](y3) x4 = tensor_model_parallel_all_reduce(z4) y4, resid = self.norm[3](x4, resid) # use resid here @@ -143,7 +130,7 @@ def ops_in_model_before(self): return [ torch.ops.vllm.all_reduce.default, torch.ops._C.static_scaled_fp8_quant.default - if 
self.fp8_linear.is_quant_fp8_enabled() + if self.fp8_linear_layers[0].is_quant_fp8_enabled() else torch.ops.aten.reciprocal.default, ] From cfb476fe539d2d49a97bf1d858fd4bb34b92d6d9 Mon Sep 17 00:00:00 2001 From: vllmellm Date: Fri, 7 Nov 2025 07:48:30 +0000 Subject: [PATCH 28/36] minor fixes Signed-off-by: vllmellm --- .../schemes/compressed_tensors_w8a8_fp8.py | 4 ++-- .../model_executor/layers/quantization/fbgemm_fp8.py | 3 ++- .../kernels/scaled_mm/ScaledMMLinearKernel.py | 1 + .../layers/quantization/kernels/scaled_mm/rocm.py | 12 ++++++------ 4 files changed, 11 insertions(+), 9 deletions(-) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py index a1c60fadce6d..2cd29e0905d0 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py @@ -70,10 +70,10 @@ def __init__(self, weight_quant: QuantizationArgs, is_static_input_scheme: bool) self.out_dtype = torch.get_default_dtype() self.is_static_input_scheme = is_static_input_scheme self.weight_block_size = self.weight_quant.block_structure - self.cutlass_block_fp8_supported = cutlass_block_fp8_supported() - self.use_aiter_and_is_supported = check_aiter_fp8_linear_support() if self.weight_block_size is not None: + self.cutlass_block_fp8_supported = cutlass_block_fp8_supported() + self.use_aiter_and_is_supported = check_aiter_fp8_linear_support() assert not self.is_static_input_scheme self.act_q_group_shape = GroupShape(1, self.weight_block_size[0]) self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp( diff --git a/vllm/model_executor/layers/quantization/fbgemm_fp8.py b/vllm/model_executor/layers/quantization/fbgemm_fp8.py index c19dd708b233..bcd02554008c 100644 --- 
a/vllm/model_executor/layers/quantization/fbgemm_fp8.py +++ b/vllm/model_executor/layers/quantization/fbgemm_fp8.py @@ -28,6 +28,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( is_layer_skipped, kFp8DynamicTokenSym, + kFp8StaticTokenSym, ) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( maybe_create_device_identity, @@ -96,7 +97,7 @@ def __init__(self, quant_config: FBGEMMFp8Config): self.out_dtype = torch.get_default_dtype() self.fp8_linear = init_fp8_linear_kernel( activation_quant_key=kFp8DynamicTokenSym, - weight_quant_key=kFp8DynamicTokenSym, + weight_quant_key=kFp8StaticTokenSym, out_dtype=torch.get_default_dtype(), module_name=self.__class__.__name__, ) diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py index 5baa7f73077a..e2b4f08f6db4 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py @@ -22,6 +22,7 @@ class ScaledMMLinearLayerConfig: @dataclass class Int8ScaledMMLinearLayerConfig(ScaledMMLinearLayerConfig): + # TODO: Chnage to QuantKey like FP8ScaledMMLinearLayerConfig is_static_input_scheme: bool is_channelwise: bool input_symmetric: bool diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/rocm.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/rocm.py index 852e0088d0d9..bcc92ef209af 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/rocm.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/rocm.py @@ -78,12 +78,12 @@ def rocm_per_tensor_float_w8a8_scaled_mm( return torch.narrow(output, 0, 0, A.shape[0]).view(*output_shape) -# if current_platform.is_rocm(): -direct_register_custom_op( - op_name="rocm_per_tensor_float_w8a8_scaled_mm_impl", - 
op_func=rocm_per_tensor_float_w8a8_scaled_mm_impl, - fake_impl=rocm_per_tensor_float_w8a8_scaled_mm_fake, -) +if current_platform.is_rocm(): + direct_register_custom_op( + op_name="rocm_per_tensor_float_w8a8_scaled_mm_impl", + op_func=rocm_per_tensor_float_w8a8_scaled_mm_impl, + fake_impl=rocm_per_tensor_float_w8a8_scaled_mm_fake, + ) class ROCmScaledMMLinearKernel(FP8ScaledMMLinearKernel): From e47d55b80f60933a565e144c8bf463bf9c8f6214 Mon Sep 17 00:00:00 2001 From: vllmellm Date: Fri, 7 Nov 2025 12:13:40 +0000 Subject: [PATCH 29/36] force kernels for tests Signed-off-by: vllmellm --- tests/compile/test_fusion.py | 81 +++++++++++---- tests/utils.py | 6 +- .../kernels/scaled_mm/__init__.py | 99 ++++++++++++------- 3 files changed, 126 insertions(+), 60 deletions(-) diff --git a/tests/compile/test_fusion.py b/tests/compile/test_fusion.py index 6d27d10f687a..aa4d2c8cf453 100644 --- a/tests/compile/test_fusion.py +++ b/tests/compile/test_fusion.py @@ -18,18 +18,34 @@ VllmConfig, ) from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.quantization.kernels.scaled_mm.cutlass import ( + CutlassFP8ScaledMMLinearKernel, +) +from vllm.model_executor.layers.quantization.kernels.scaled_mm.flashinfer import ( + FlashInferScaledMMLinearKernel, +) +from vllm.model_executor.layers.quantization.kernels.scaled_mm.pytorch import ( + ChannelWiseTorchScaledMMLinearKernel, + PerTensorTorchScaledMMLinearKernel, + RowWiseTorchScaledMMLinearKernel, +) +from vllm.model_executor.layers.quantization.kernels.scaled_mm.rocm import ( + ROCmScaledMMLinearKernel, +) +from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501 + FP8ScaledMMLinearKernel, +) from vllm.model_executor.layers.quantization.utils.quant_utils import ( GroupShape, QuantKey, ScaleDesc, ) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - cutlass_fp8_supported, maybe_create_device_identity, ) from vllm.platforms import 
current_platform -from ..utils import TestFP8Layer, override_cutlass_fp8_supported +from ..utils import TestFP8Layer from .backend import TestBackend FP8_DTYPE = current_platform.fp8_dtype() @@ -44,14 +60,12 @@ def __init__( hidden_size: int, eps: float, static: bool, - cuda_force_torch: bool, + force_kernel: FP8ScaledMMLinearKernel, *args, **kwargs, ): super().__init__(*args, **kwargs) - self.cuda_force_torch = cuda_force_torch self.norm = [RMSNorm(hidden_size, eps) for _ in range(4)] - self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(3)] group_shape = GroupShape.PER_TENSOR if static else GroupShape.PER_TOKEN act_quant_scale = ScaleDesc(torch.float32, static, group_shape) @@ -67,22 +81,30 @@ def __init__( self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(3)] else: self.scale = [None for _ in range(3)] + + if group_shape == GroupShape.PER_TOKEN: + self.wscale = [ + torch.rand((hidden_size, 1), dtype=torch.float32) for _ in range(3) + ] + else: + self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(3)] + self.w = [ torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t() for _ in range(3) ] - with override_cutlass_fp8_supported(not cuda_force_torch): - self.fp8_linear_layers = [ - TestFP8Layer( - self.activation_quant_key, - self.weight_quant_key, - self.w[i], - self.wscale[i], - input_scale=self.scale[i], - ) - for i in range(3) - ] + self.fp8_linear_layers = [ + TestFP8Layer( + self.activation_quant_key, + self.weight_quant_key, + self.w[i], + self.wscale[i], + input_scale=self.scale[i], + force_kernel=force_kernel, + ) + for i in range(3) + ] self.enable_rms_norm_custom_op = self.norm[0].enabled() self.enable_quant_fp8_custom_op = self.fp8_linear_layers[ @@ -128,6 +150,21 @@ def ops_in_model_before_partial(self): ) +ROCM_FP8_KERNELS = [ + ROCmScaledMMLinearKernel, + PerTensorTorchScaledMMLinearKernel, + RowWiseTorchScaledMMLinearKernel, + ChannelWiseTorchScaledMMLinearKernel, +] + +CUDA_FP8_KERNELS = [ + 
FlashInferScaledMMLinearKernel, + CutlassFP8ScaledMMLinearKernel, + PerTensorTorchScaledMMLinearKernel, + ChannelWiseTorchScaledMMLinearKernel, +] + + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("hidden_size", [64]) @pytest.mark.parametrize("num_tokens", [257]) @@ -135,10 +172,8 @@ def ops_in_model_before_partial(self): @pytest.mark.parametrize("static", [True, False]) @pytest.mark.parametrize("enable_rms_norm_custom_op", [True, False]) @pytest.mark.parametrize("enable_quant_fp8_custom_op", [True, False]) -# cuda_force_torch used to test torch code path on platforms that -# cutlass_fp8_supported() == True. @pytest.mark.parametrize( - "cuda_force_torch", [True, False] if cutlass_fp8_supported() else [True] + "force_kernel", CUDA_FP8_KERNELS if current_platform.is_cuda() else ROCM_FP8_KERNELS ) @pytest.mark.skipif( not current_platform.is_cuda_alike(), reason="Only test on CUDA and ROCm" @@ -151,7 +186,7 @@ def test_fusion_rmsnorm_quant( static, enable_rms_norm_custom_op, enable_quant_fp8_custom_op, - cuda_force_torch, + force_kernel, ): torch.set_default_device("cuda") torch.set_default_dtype(dtype) @@ -179,8 +214,12 @@ def test_fusion_rmsnorm_quant( backend = TestBackend(noop_pass, fusion_pass, cleanup_pass) backend2 = TestBackend(noop_pass, cleanup_pass) - model = TestModel(hidden_size, eps, static, cuda_force_torch) + model = TestModel(hidden_size, eps, static, force_kernel) + # skip the test if we cannot force the kernel + selected_kernels = [layer.kernel for layer in model.fp8_linear_layers] + if not any(isinstance(kernel, force_kernel) for kernel in selected_kernels): + pytest.skip(f"{force_kernel.__name__} couldn't be forced") # First dimension dynamic x = torch.rand(num_tokens, hidden_size) torch._dynamo.mark_dynamic(x, 0) diff --git a/tests/utils.py b/tests/utils.py index ba28886e6079..8fac003b8e9b 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -45,6 +45,9 @@ from 
vllm.model_executor.layers.quantization.kernels.scaled_mm import ( init_fp8_linear_kernel, ) +from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501 + FP8ScaledMMLinearKernel, +) from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey from vllm.model_executor.model_loader import get_model_loader from vllm.platforms import current_platform @@ -1443,6 +1446,7 @@ def __init__( weight_scale: torch.Tensor, input_scale: torch.Tensor, out_dtype: torch.dtype | None = None, + force_kernel: FP8ScaledMMLinearKernel | None = None, ): super().__init__() self.weight_scale = weight_scale @@ -1454,7 +1458,7 @@ def __init__( activation_quant_key=activation_quant_key, weight_quant_key=weight_quant_key, out_dtype=out_dtype, - module_name=self.__class__.__name__, + force_kernel=force_kernel, ) def is_quant_fp8_enabled(self) -> bool: diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py index 4a3f74f59126..b033cc7905e4 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py @@ -61,7 +61,6 @@ FlashInferScaledMMLinearKernel, CutlassFP8ScaledMMLinearKernel, PerTensorTorchScaledMMLinearKernel, - RowWiseTorchScaledMMLinearKernel, ChannelWiseTorchScaledMMLinearKernel, ], PlatformEnum.ROCM: [ @@ -76,10 +75,38 @@ _KernelConfigT = TypeVar("_KernelConfigT", bound=ScaledMMLinearLayerConfig) +def can_implement_scaled_mm_linear_kernel( + kernel: type[_KernelT], config: _KernelConfigT, compute_capability: int | None +) -> tuple[bool, str]: + if kernel.__name__ in os.environ.get("VLLM_DISABLED_KERNELS", "").split(","): + return False, f" {kernel.__name__} disabled by environment variable" + + # If the current platform uses compute_capability, + # make sure the kernel supports the compute cability. 
+ if compute_capability is not None: + kernel_min_capability = kernel.get_min_capability() + if ( + kernel_min_capability is not None + and kernel_min_capability > compute_capability + ): + return ( + False, + f"{kernel.__name__} requires capability " + f"{kernel_min_capability}, current compute capability " + f"is {compute_capability}", + ) + can_implement, failure_reason = kernel.can_implement(config) + if not can_implement: + return (False, f" {kernel.__name__} cannot implement due to: {failure_reason}") + + return True, "" + + def choose_scaled_mm_linear_kernel( config: _KernelConfigT, possible_kernels: dict[PlatformEnum, list[type[_KernelT]]], compute_capability: int | None = None, + force_kernel: type[_KernelT] | None = None, ) -> type[_KernelT]: """ Choose a _KernelT that can implement the given config for the @@ -94,6 +121,9 @@ def choose_scaled_mm_linear_kernel( compute_capability (Optional[int], optional): The compute capability of the target device, if None uses `current_platform` to get the compute capability. Defaults to None. + force_kernel (Optional[type[_KernelT]]): An Optional forced kernel to override + the possible_kernels if it can be implemented. If None, it will only try the + possible kernels. Raises: ValueError: If no kernel can implement the given config. 
@@ -107,40 +137,32 @@ def choose_scaled_mm_linear_kernel( if _cc is not None: compute_capability = _cc[0] * 10 + _cc[1] - failure_reasons = [] + failure_reason_list = [] + + if force_kernel is not None: + can_implement, failure_reason = can_implement_scaled_mm_linear_kernel( + force_kernel, config, compute_capability + ) + if can_implement: + return force_kernel + + logger.info_once( + "Tried to force %s, but the kernel couldn't be implemented", + force_kernel.__name__, + scope="global", + ) + for kernel in possible_kernels[current_platform._enum]: - if kernel.__name__ in os.environ.get("VLLM_DISABLED_KERNELS", "").split(","): - failure_reasons.append( - f" {kernel.__name__} disabled by environment variable" - ) - continue - - # If the current platform uses compute_capability, - # make sure the kernel supports the compute cability. - if compute_capability is not None: - kernel_min_capability = kernel.get_min_capability() - if ( - kernel_min_capability is not None - and kernel_min_capability > compute_capability - ): - failure_reasons.append( - f"{kernel.__name__} requires capability " - f"{kernel_min_capability}, current compute capability " - f"is {compute_capability}" - ) - continue - - can_implement, failure_reason = kernel.can_implement(config) + can_implement, failure_reason = can_implement_scaled_mm_linear_kernel( + kernel, config, compute_capability + ) if can_implement: return kernel - else: - failure_reasons.append( - f" {kernel.__name__} cannot implement due to: {failure_reason}" - ) + failure_reason_list.append(failure_reason) raise ValueError( "Failed to find a kernel that can implement the " - "ScaledMM linear layer. Reasons: \n" + "\n".join(failure_reasons) + "ScaledMM linear layer. 
Reasons: \n" + "\n".join(failure_reason_list) ) @@ -148,7 +170,8 @@ def init_fp8_linear_kernel( activation_quant_key: QuantKey, weight_quant_key: QuantKey, out_dtype: torch.dtype, - module_name: str, + force_kernel: type[FP8ScaledMMLinearKernel] | None = None, + module_name: str | None = None, ) -> FP8ScaledMMLinearKernel: scaled_mm_linear_kernel_config = FP8ScaledMMLinearLayerConfig( weight_quant_key=weight_quant_key, @@ -157,16 +180,16 @@ def init_fp8_linear_kernel( ) kernel_type = choose_scaled_mm_linear_kernel( - scaled_mm_linear_kernel_config, - _POSSIBLE_FP8_KERNELS, + scaled_mm_linear_kernel_config, _POSSIBLE_FP8_KERNELS, force_kernel=force_kernel ) - logger.info_once( - "Selected %s for %s", - kernel_type.__name__, - module_name, - scope="global", - ) + if module_name: + logger.info_once( + "Selected %s for %s", + kernel_type.__name__, + module_name, + scope="global", + ) return kernel_type( scaled_mm_linear_kernel_config, From edb6d43a371ea3a7e425c273950ab4af1ccff0f8 Mon Sep 17 00:00:00 2001 From: vllmellm Date: Fri, 7 Nov 2025 12:14:54 +0000 Subject: [PATCH 30/36] ensure static scales for ChannelWiseTorchScaledMMLinearKernel; remove comment Signed-off-by: vllmellm --- .../layers/quantization/kernels/scaled_mm/pytorch.py | 9 +++++++++ .../layers/quantization/kernels/scaled_mm/rocm.py | 1 - 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/pytorch.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/pytorch.py index 8c0f0e1d57fb..1736f145de02 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/pytorch.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/pytorch.py @@ -209,10 +209,19 @@ def get_scaled_mm_func(self) -> Callable[..., torch.Tensor]: class ChannelWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel): @classmethod def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: + is_static = 
c.activation_quant_key.scale.static + per_tensor_activation_scales = ( c.activation_quant_key.scale.group_shape.is_per_tensor() ) per_tensor_weight_scales = c.weight_quant_key.scale.group_shape.is_per_tensor() + + if not is_static: + return ( + False, + "ChannelWiseTorchScaledMMLinearKernel requires static scales", + ) + if per_tensor_activation_scales and per_tensor_weight_scales: return ( False, diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/rocm.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/rocm.py index bcc92ef209af..493507ba4313 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/rocm.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/rocm.py @@ -89,7 +89,6 @@ def rocm_per_tensor_float_w8a8_scaled_mm( class ROCmScaledMMLinearKernel(FP8ScaledMMLinearKernel): @classmethod def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: - # TODO: check if this causes an issue on non-ROCM platforms from vllm.platforms.rocm import on_mi3xx per_tensor_activation_scales = ( From 45a3008ceb0b9f55b23fdc3dc8d4f4be480b86aa Mon Sep 17 00:00:00 2001 From: vllmellm Date: Mon, 10 Nov 2025 08:02:17 +0000 Subject: [PATCH 31/36] feat: Integrate AITER bpreshuffle and ck operators on top of fp8 refactor Signed-off-by: vllmellm --- .../schemes/compressed_tensors_w8a8_fp8.py | 2 + .../kernels/scaled_mm/__init__.py | 4 + .../quantization/kernels/scaled_mm/aiter.py | 217 +++++++++++++++++- 3 files changed, 222 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py index 2cd29e0905d0..e25d2aaa439b 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py 
@@ -192,6 +192,8 @@ def process_weights_after_loading(self, layer) -> None: if self.strategy == QuantizationStrategy.BLOCK: maybe_post_process_fp8_weight_block(layer, self.cutlass_block_fp8_supported) + self.fp8_linear.process_weights_after_loading(layer) + def apply_weights( self, layer: torch.nn.Module, diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py index b033cc7905e4..b8c7f78aac64 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py @@ -8,6 +8,8 @@ from vllm.logger import init_logger from vllm.model_executor.layers.quantization.kernels.scaled_mm.aiter import ( + AiterBpreshufflePerTokenFp8ScaledMMLinearKernel, + AiterCKPerTokenFp8ScaledMMLinearKernel, AiterScaledMMLinearKernel, ) from vllm.model_executor.layers.quantization.kernels.scaled_mm.cpu import ( @@ -64,6 +66,8 @@ ChannelWiseTorchScaledMMLinearKernel, ], PlatformEnum.ROCM: [ + AiterBpreshufflePerTokenFp8ScaledMMLinearKernel, + AiterCKPerTokenFp8ScaledMMLinearKernel, ROCmScaledMMLinearKernel, PerTensorTorchScaledMMLinearKernel, RowWiseTorchScaledMMLinearKernel, diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py index 3ac90553bbc7..430e407156c5 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py @@ -2,15 +2,25 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections.abc import Callable + import torch +from aiter.ops.shuffle import shuffle_weight import vllm.envs as envs from vllm import _custom_ops as ops +from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.utils.torch_utils import direct_register_custom_op from .cutlass import 
CutlassScaledMMLinearKernel -from .ScaledMMLinearKernel import Int8ScaledMMLinearLayerConfig +from .ScaledMMLinearKernel import ( + FP8ScaledMMLinearKernel, + FP8ScaledMMLinearLayerConfig, + Int8ScaledMMLinearLayerConfig, +) + +logger = init_logger(__name__) def rocm_aiter_gemm_w8a8_impl( @@ -52,6 +62,54 @@ def rocm_aiter_gemm_w8a8_fake( ) +# bpshuffle +def rocm_aiter_gemm_a8w8_bpreshuffle_impl( + input: torch.Tensor, + weight: torch.Tensor, + out_dtype: torch.dtype | None = None, + scale_a: torch.Tensor | None = None, + scale_b: torch.Tensor | None = None, +) -> torch.Tensor: + # This AITER function can be used for + # - per-token activations + per-channel weights + # e.g. vllm/model_executor/layers/quantization/utils/w8a8_utils.py + # accept the weight as # keep the weight as (N, K) + # NOTE: The weight has to be shuffled in the + # process_weights_after_loading of the CompressedTensorsW8A8Fp8 class + + from aiter import gemm_a8w8_bpreshuffle_ck + + m = input.shape[0] + n = weight.shape[0] + Y = torch.empty(m, n, dtype=out_dtype, device=input.device) + gemm_a8w8_bpreshuffle_ck(input, weight, scale_a, scale_b, Y) + return Y + + +def rocm_aiter_gemm_a8w8_bpreshuffle_fake( + input: torch.Tensor, + weight: torch.Tensor, + out_dtype: torch.dtype | None = None, + scale_a: torch.Tensor | None = None, + scale_b: torch.Tensor | None = None, +) -> torch.Tensor: + m = input.shape[0] + n = weight.shape[0] + if out_dtype is None: + out_dtype = input.dtype + return torch.empty((m, n), dtype=out_dtype, device=input.device) + + +if current_platform.is_rocm(): + direct_register_custom_op( + op_name="rocm_aiter_gemm_a8w8_bpreshuffle", + op_func=rocm_aiter_gemm_a8w8_bpreshuffle_impl, + mutates_args=[], + fake_impl=rocm_aiter_gemm_a8w8_bpreshuffle_fake, + dispatch_key=current_platform.dispatch_key, + ) + + class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel): @classmethod def get_min_capability(cls) -> int: @@ -157,3 +215,160 @@ def apply_weights( return 
torch.ops.vllm.rocm_aiter_gemm_w8a8( x_q, w_q.t(), x_s, w_s, bias, out_dtype ) + + +# bpreshuffle +class AiterBpreshufflePerTokenFp8ScaledMMLinearKernel(FP8ScaledMMLinearKernel): + def get_ouput_padding(self) -> int | None: + # PTPC kernels do not require padding. + return None + + @classmethod + def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: + if not current_platform.is_rocm(): + return (False, "AITER bpreshuffle is ROCm-only") + if not (envs.VLLM_ROCM_USE_AITER_LINEAR and envs.VLLM_ROCM_USE_AITER): + return (False, "AITER bpreshuffle is disabled by env var") + try: + import aiter # noqa: F401 + except Exception: + return (False, "AITER not installed") + + # Check if the configuration is PTPC + is_per_channel_weight = c.weight_quant_key.scale.group_shape.is_per_token() + is_per_token_activation = ( + c.activation_quant_key.scale.group_shape.is_per_token() + ) + is_ptpc = is_per_channel_weight and is_per_token_activation + + logger.info_once(f"AiterBpreshuffle: can_implement called. is_ptpc={is_ptpc}") + + if not is_ptpc: + return (False, "This kernel only handles Per-Token/Per-Channel (PTPC)") + + return True, None + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + logger.info_once("AiterBpreshuffle: SHUFFLING WEIGHTS NOW.") + + w_q, _, _, _ = self._get_layer_params(layer) + + N = w_q.shape[1] + K = w_q.shape[0] + + if N % 16 == 0 and K % 16 == 0: + # AITER shuffle_weight expectation [N, K] + w_q_nk = w_q.t().contiguous() + + # Execute shuffle + shuffled_w_nk = shuffle_weight(w_q_nk, layout=(16, 16)) + + del layer.weight + layer.register_buffer("weight", shuffled_w_nk) + + logger.info_once("[AiterBpreshuffle: Weight shuffle COMPLETE.") + + else: + raise ValueError( + f"Weight shape (N={N}, K={K}) not divisible by 16 " + "for AITER bpreshuffle." + ) + + def apply_weights( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: torch.Tensor | None = None, + ) -> torch.Tensor: + # 1. 
Obtain parameters + w_q, w_s, x_s, x_s_ub = self._get_layer_params(layer) + # 2. Dynamic quantization input + qinput, qinput_scale = self.quant_fp8(x, x_s, x_s_ub) + + logger.info_once( + "AiterBpreshuffle: apply_weights... ABOUT TO CALL C++ KERNEL..." + ) + + # 3. Call the AITER bpreshuffle CK operator. + output = torch.ops.vllm.rocm_aiter_gemm_a8w8_bpreshuffle( + qinput, + w_q, # Input [N, K] shuffle weights + out_dtype=self.config.out_dtype, + scale_a=qinput_scale, + scale_b=w_s, + ) + + logger.info_once("AiterBpreshuffle: C++ KERNEL CALL SUCCEEDED.") + + if bias is not None: + output.add_(bias) + return output + + def get_scaled_mm_func(self) -> Callable[..., torch.Tensor]: + return rocm_aiter_gemm_a8w8_bpreshuffle_impl + + +# AITER FP8 CK +class AiterCKPerTokenFp8ScaledMMLinearKernel(FP8ScaledMMLinearKernel): + """ + AITER PTPC kernel (gemm_a8w8_CK) without pre-shuffling. + """ + + def get_ouput_padding(self) -> int | None: + return None + + @classmethod + def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: + if not current_platform.is_rocm(): + return (False, "AITER CK is ROCm-only") + if not (envs.VLLM_ROCM_USE_AITER_LINEAR and envs.VLLM_ROCM_USE_AITER): + return (False, "AITER CK is disabled by env var") + try: + import aiter # noqa: F401 + except Exception: + return (False, "AITER not installed") + + is_per_channel_weight = c.weight_quant_key.scale.group_shape.is_per_token() + is_per_token_activation = ( + c.activation_quant_key.scale.group_shape.is_per_token() + ) + is_ptpc = is_per_channel_weight and is_per_token_activation + + logger.info_once(f"AiterCK: can_implement called. is_ptpc={is_ptpc}") + + if not is_ptpc: + return (False, "This kernel only handles Per-Token/Per-Channel (PTPC)") + + return True, None + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + logger.info_once( + "AITER CK: process_weights_after_loading... DOING NOTHING (pass)." 
+ ) + pass + + def apply_weights( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: torch.Tensor | None = None, + ) -> torch.Tensor: + w_q, w_s, x_s, x_s_ub = self._get_layer_params(layer) + + qinput, qinput_scale = self.quant_fp8(x, x_s, x_s_ub) + + logger.info_once( + "AiterCK: apply_weights... " + "ABOUT TO CALL C++ KERNEL (this is where it hangs)..." + ) + + output = torch.ops.vllm.rocm_aiter_gemm_w8a8( + qinput, w_q.t(), qinput_scale, w_s, bias, self.config.out_dtype + ) + + logger.info_once("AiterCK: C++ KERNEL CALL SUCCEEDED.") + + return output + + def get_scaled_mm_func(self) -> Callable[..., torch.Tensor]: + return rocm_aiter_gemm_w8a8_impl From 858765f59f7485cf458084ad6a64a5e018f44d3a Mon Sep 17 00:00:00 2001 From: vllmellm Date: Mon, 10 Nov 2025 11:17:46 +0000 Subject: [PATCH 32/36] fix output padding for torch _scaled_mm Signed-off-by: vllmellm --- .../schemes/compressed_tensors_w8a8_fp8.py | 2 +- .../quantization/kernels/scaled_mm/pytorch.py | 13 +++++++++---- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py index 2cd29e0905d0..dc735a115e0b 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py @@ -88,7 +88,7 @@ def __init__(self, weight_quant: QuantizationArgs, is_static_input_scheme: bool) self.fp8_linear = init_fp8_linear_kernel( activation_quant_key=activation_quant_key, weight_quant_key=weight_quant_key, - out_dtype=torch.get_default_dtype(), + out_dtype=self.out_dtype, module_name=self.__class__.__name__, ) diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/pytorch.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/pytorch.py index 
1736f145de02..c272f579d8bc 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/pytorch.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/pytorch.py @@ -44,7 +44,7 @@ def torch_per_tensor_w8a8_scaled_mm( if type(output) is tuple and len(output) == 2: output = output[0] - return torch.narrow(output, 0, 0, A.shape[0]).view(*output_shape) + return torch.narrow(output, 0, 0, output_shape[0]).view(*output_shape) def torch_row_wise_w8a8_scaled_mm( @@ -77,7 +77,7 @@ def torch_row_wise_w8a8_scaled_mm( bias=bias, ) - output = torch.narrow(output, 0, 0, A.shape[0]) + output = torch.narrow(output, 0, 0, output_shape[0]) output = output.view(*output_shape) return output @@ -121,8 +121,8 @@ def torch_channelwise_w8a8_scaled_mm( if type(output) is tuple and len(output) == 2: output = output[0] # Unpad (undo num_token_padding) - output = torch.narrow(output, 0, 0, A.shape[0]) - x_scale = torch.narrow(As, 0, 0, A.shape[0]) + output = torch.narrow(output, 0, 0, output_shape[0]) + x_scale = torch.narrow(As, 0, 0, output_shape[0]) # DQ # C = sw * sx * (X * W) + bias @@ -142,6 +142,11 @@ class TorchScaledMMLinearKernel(FP8ScaledMMLinearKernel): """ def get_ouput_padding(self) -> int | None: + # Note: we pad the input because torch._scaled_mm is more performant + # for matrices with batch dimension > 16. + # This could change in the future. + # We also don't pad when using torch.compile, + # as it breaks with dynamic shapes. 
vllm_config = get_current_vllm_config().compilation_config pad_output = vllm_config.mode < CompilationMode.VLLM_COMPILE output_padding = 17 if pad_output else None From 65ecf487ad134e521cc6fe93a370fbaa5d989d92 Mon Sep 17 00:00:00 2001 From: vllmellm Date: Wed, 12 Nov 2025 07:43:16 +0000 Subject: [PATCH 33/36] optional input scales Signed-off-by: vllmellm --- tests/utils.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/utils.py b/tests/utils.py index 1fc8b260d1e9..848c4efa8bcd 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1325,7 +1325,8 @@ class TestFP8Layer(torch.nn.Module): weight_quant_key (QuantKey): Key for weight quantization configuration. weight (torch.Tensor): Weight tensor for linear transformation. weight_scale (torch.Tensor): Per-tensor or per-group scale for weights. - input_scale (torch.Tensor): Scale tensor for input quantization. + input_scale (torch.Tensor, optional): Scale tensor for input quantization. + Defaults to None. out_dtype (torch.dtype, optional): Output tensor data type. Defaults to torch.get_default_dtype(). 
""" @@ -1336,7 +1337,7 @@ def __init__( weight_quant_key: QuantKey, weight: torch.Tensor, weight_scale: torch.Tensor, - input_scale: torch.Tensor, + input_scale: torch.Tensor | None = None, out_dtype: torch.dtype | None = None, force_kernel: FP8ScaledMMLinearKernel | None = None, ): From 405d2802c68237703068e1ae6f157fc5c4683c2f Mon Sep 17 00:00:00 2001 From: vllmellm Date: Wed, 12 Nov 2025 15:52:49 +0000 Subject: [PATCH 34/36] remove maybe_create_device_identity Signed-off-by: vllmellm --- tests/compile/test_fusion.py | 4 --- tests/compile/test_silu_mul_quant_fusion.py | 4 --- .../schemes/compressed_tensors_w8a8_fp8.py | 3 -- .../layers/quantization/fbgemm_fp8.py | 2 -- .../model_executor/layers/quantization/fp8.py | 3 -- .../quantization/kernels/scaled_mm/pytorch.py | 30 +++++-------------- .../layers/quantization/utils/w8a8_utils.py | 22 -------------- 7 files changed, 7 insertions(+), 61 deletions(-) diff --git a/tests/compile/test_fusion.py b/tests/compile/test_fusion.py index aa4d2c8cf453..9a5af9f1d245 100644 --- a/tests/compile/test_fusion.py +++ b/tests/compile/test_fusion.py @@ -40,9 +40,6 @@ QuantKey, ScaleDesc, ) -from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - maybe_create_device_identity, -) from vllm.platforms import current_platform from ..utils import TestFP8Layer @@ -191,7 +188,6 @@ def test_fusion_rmsnorm_quant( torch.set_default_device("cuda") torch.set_default_dtype(dtype) torch.manual_seed(1) - maybe_create_device_identity() # needed for certain non-cutlass fp8 paths custom_ops = [] if enable_rms_norm_custom_op: diff --git a/tests/compile/test_silu_mul_quant_fusion.py b/tests/compile/test_silu_mul_quant_fusion.py index 56b36856f7f2..c4f6f2d4c4d9 100644 --- a/tests/compile/test_silu_mul_quant_fusion.py +++ b/tests/compile/test_silu_mul_quant_fusion.py @@ -28,9 +28,6 @@ kFp8StaticTensorSym, kNvfp4Quant, ) -from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - maybe_create_device_identity, -) from 
vllm.platforms import current_platform from ..utils import TestFP8Layer, override_cutlass_fp8_supported @@ -157,7 +154,6 @@ def test_fusion_silu_and_mul_quant( torch.set_default_device("cuda") torch.set_default_dtype(dtype) - maybe_create_device_identity() x = torch.rand(num_tokens, hidden_size * 2) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py index 01204b10ea18..5480383126e3 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py @@ -34,7 +34,6 @@ ) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( cutlass_block_fp8_supported, - maybe_create_device_identity, ) from vllm.model_executor.parameter import ( BlockQuantScaleParameter, @@ -108,8 +107,6 @@ def create_weights( weight_loader: Callable, **kwargs, ): - maybe_create_device_identity() - output_size_per_partition = sum(output_partition_sizes) layer.logical_widths = output_partition_sizes layer.weight_block_size = None diff --git a/vllm/model_executor/layers/quantization/fbgemm_fp8.py b/vllm/model_executor/layers/quantization/fbgemm_fp8.py index bcd02554008c..45d2e4e33819 100644 --- a/vllm/model_executor/layers/quantization/fbgemm_fp8.py +++ b/vllm/model_executor/layers/quantization/fbgemm_fp8.py @@ -31,7 +31,6 @@ kFp8StaticTokenSym, ) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - maybe_create_device_identity, normalize_e4m3fn_to_e4m3fnuz, ) from vllm.model_executor.parameter import ( @@ -112,7 +111,6 @@ def create_weights( params_dtype: torch.dtype, **extra_weight_attrs, ): - maybe_create_device_identity() weight_loader = extra_weight_attrs.get("weight_loader") del input_size, output_size output_size_per_partition = sum(output_partition_sizes) 
diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 91115d7437e2..57b3736ed7fd 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -86,7 +86,6 @@ all_close_1d, cutlass_block_fp8_supported, cutlass_fp8_supported, - maybe_create_device_identity, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize, ) @@ -416,8 +415,6 @@ def create_weights( params_dtype: torch.dtype, **extra_weight_attrs, ): - maybe_create_device_identity() - output_size_per_partition = sum(output_partition_sizes) weight_loader = extra_weight_attrs.get("weight_loader") layer.logical_widths = output_partition_sizes diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/pytorch.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/pytorch.py index c272f579d8bc..ad21d68c0f52 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/pytorch.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/pytorch.py @@ -14,17 +14,6 @@ FP8ScaledMMLinearLayerConfig, ) -# Input scaling factors are no longer optional in _scaled_mm starting -# from pytorch 2.5. Allocating a dummy tensor to pass as input_scale -TORCH_DEVICE_IDENTITY = None - - -def maybe_create_device_identity(): - # Allocate dummy ones tensor for torch._scaled_mm - global TORCH_DEVICE_IDENTITY - if TORCH_DEVICE_IDENTITY is None: - TORCH_DEVICE_IDENTITY = torch.ones(1, dtype=torch.float32) - def torch_per_tensor_w8a8_scaled_mm( *, @@ -57,8 +46,7 @@ def torch_row_wise_w8a8_scaled_mm( bias: torch.Tensor, output_shape: list, ) -> torch.Tensor: - # Note: Callers of this function should check USE_ROWWISE_TORCH_SCALED_MM - # when using it. + # Note: # For now it has only been validated on ROCm platform. 
# fp8 rowwise scaling in torch._scaled_mm is introduced in # https://github.com/pytorch/pytorch/pull/144432 using @@ -106,14 +94,18 @@ def torch_channelwise_w8a8_scaled_mm( # For the scaled_mm fallback case, we break this down, since it # does not support s_w being a vector. + # Input scaling factors are no longer optional in _scaled_mm starting + # from pytorch 2.5. Allocating a dummy tensor to pass as scales + dummy_tensor = torch.ones(1, dtype=torch.float32, device=A.device) + # GEMM # This computes C = (X * W). # Output in fp32 to allow subsequent ops to happen in-place output = torch._scaled_mm( A, B, - scale_a=TORCH_DEVICE_IDENTITY, - scale_b=TORCH_DEVICE_IDENTITY, + scale_a=dummy_tensor, + scale_b=dummy_tensor, out_dtype=torch.float32, ) # A fix for discrepancy in scaled_mm which returns tuple @@ -214,19 +206,11 @@ def get_scaled_mm_func(self) -> Callable[..., torch.Tensor]: class ChannelWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel): @classmethod def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: - is_static = c.activation_quant_key.scale.static - per_tensor_activation_scales = ( c.activation_quant_key.scale.group_shape.is_per_tensor() ) per_tensor_weight_scales = c.weight_quant_key.scale.group_shape.is_per_tensor() - if not is_static: - return ( - False, - "ChannelWiseTorchScaledMMLinearKernel requires static scales", - ) - if per_tensor_activation_scales and per_tensor_weight_scales: return ( False, diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py index f2d8eecdc68e..c7fcb5a4b33b 100644 --- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py @@ -3,25 +3,10 @@ import torch -from packaging import version from vllm import _custom_ops as ops from vllm.platforms import current_platform -# Input scaling factors are no longer optional in _scaled_mm starting -# 
from pytorch 2.5. Allocating a dummy tensor to pass as input_scale -TORCH_DEVICE_IDENTITY = None - -# The condition to determine if it is on a platform that supports -# torch._scaled_mm rowwise feature. -# The condition is determined once as the operations -# are time-consuming. -USE_ROWWISE_TORCH_SCALED_MM = ( - current_platform.is_rocm() - and version.parse(torch.__version__) >= version.parse("2.7") - and current_platform.has_device_capability(94) -) - def sparse_cutlass_supported() -> bool: if not current_platform.is_cuda(): @@ -129,13 +114,6 @@ def requantize_with_max_scale( return max_w_scale, weight -def maybe_create_device_identity(): - # Allocate dummy ones tensor for torch._scaled_mm - global TORCH_DEVICE_IDENTITY - if TORCH_DEVICE_IDENTITY is None: - TORCH_DEVICE_IDENTITY = torch.ones(1, dtype=torch.float32) - - def normalize_e4m3fn_to_e4m3fnuz( weight: torch.Tensor, weight_scale: torch.Tensor, From 10eebd48961cd239e5e5ac983043c72d7127dac8 Mon Sep 17 00:00:00 2001 From: vllmellm Date: Fri, 14 Nov 2025 07:46:37 +0000 Subject: [PATCH 35/36] add CPU kernels; fix fp8 quant type selection Signed-off-by: vllmellm --- vllm/model_executor/layers/quantization/fp8.py | 7 +++---- .../layers/quantization/kernels/scaled_mm/__init__.py | 4 ++++ 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index a0159c39092c..67e5b65de601 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -77,9 +77,8 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( GroupShape, is_layer_skipped, - kFp8DynamicTensorSym, + kFp8DynamicTokenSym, kFp8StaticTensorSym, - kFp8StaticTokenSym, ) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( all_close_1d, @@ -381,10 +380,10 @@ def __init__(self, quant_config: Fp8Config): # Use per-token quantization for better perf if dynamic and cutlass if 
not self.act_q_static and cutlass_fp8_supported(): self.act_q_group_shape = GroupShape.PER_TOKEN - self.activation_quant_key = kFp8StaticTokenSym + self.activation_quant_key = kFp8DynamicTokenSym else: self.act_q_group_shape = GroupShape.PER_TENSOR - self.activation_quant_key = kFp8DynamicTensorSym + self.activation_quant_key = kFp8StaticTensorSym if self.block_quant: assert not self.act_q_static diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py index b033cc7905e4..36e4a16c0168 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py @@ -69,6 +69,10 @@ RowWiseTorchScaledMMLinearKernel, ChannelWiseTorchScaledMMLinearKernel, ], + PlatformEnum.CPU: [ + PerTensorTorchScaledMMLinearKernel, + ChannelWiseTorchScaledMMLinearKernel, + ], } _KernelT = TypeVar("_KernelT", bound=ScaledMMLinearKernel) From 679a7cffdc3baf2a2f205d993a60a8925ebfd358 Mon Sep 17 00:00:00 2001 From: vllmellm Date: Mon, 17 Nov 2025 06:29:29 +0000 Subject: [PATCH 36/36] WIP: Integrate Aiter bpreshuffle and ck kernels Signed-off-by: vllmellm --- vllm/_aiter_ops.py | 56 ++++++ .../kernels/scaled_mm/__init__.py | 4 + .../quantization/kernels/scaled_mm/aiter.py | 165 ++++++++++++++++++ 3 files changed, 225 insertions(+) diff --git a/vllm/_aiter_ops.py b/vllm/_aiter_ops.py index 5508e59bcd2f..6de21176e948 100644 --- a/vllm/_aiter_ops.py +++ b/vllm/_aiter_ops.py @@ -402,6 +402,42 @@ def _rocm_aiter_rmsnorm2d_fwd_with_add_fake( return torch.empty_like(x), torch.empty_like(residual) +def _rocm_aiter_gemm_a8w8_bpreshuffle_impl( + input: torch.Tensor, + weight: torch.Tensor, + out_dtype: torch.dtype | None = None, + scale_a: torch.Tensor | None = None, + scale_b: torch.Tensor | None = None, +) -> torch.Tensor: + # This AITER function can be used for + # - per-token activations + per-channel weights + # accept 
the weight as # keep the weight as (N, K) + # NOTE: The weight has to be shuffled in the + # process_weights_after_loading of the CompressedTensorsW8A8Fp8 class + + from aiter import gemm_a8w8_bpreshuffle_ck + + m = input.shape[0] + n = weight.shape[0] + Y = torch.empty(m, n, dtype=out_dtype, device=input.device) + gemm_a8w8_bpreshuffle_ck(input, weight, scale_a, scale_b, Y) + return Y + + +def _rocm_aiter_gemm_a8w8_bpreshuffle_fake( + input: torch.Tensor, + weight: torch.Tensor, + out_dtype: torch.dtype | None = None, + scale_a: torch.Tensor | None = None, + scale_b: torch.Tensor | None = None, +) -> torch.Tensor: + m = input.shape[0] + n = weight.shape[0] + if out_dtype is None: + out_dtype = input.dtype + return torch.empty((m, n), dtype=out_dtype, device=input.device) + + # Global flag to ensure ops are registered only once _OPS_REGISTERED = False @@ -592,6 +628,14 @@ def register_ops_once() -> None: dispatch_key=current_platform.dispatch_key, ) + direct_register_custom_op( + op_name="rocm_aiter_gemm_a8w8_bpreshuffle", + op_func=_rocm_aiter_gemm_a8w8_bpreshuffle_impl, + mutates_args=[], + fake_impl=_rocm_aiter_gemm_a8w8_bpreshuffle_fake, + dispatch_key=current_platform.dispatch_key, + ) + _OPS_REGISTERED = True @staticmethod @@ -635,6 +679,18 @@ def gemm_a8w8_blockscale( A, B, As, Bs, output_dtype ) + @staticmethod + def gemm_a8w8_bpreshuffle( + input: torch.Tensor, + weight: torch.Tensor, + out_dtype: torch.dtype | None = None, + scale_a: torch.Tensor | None = None, + scale_b: torch.Tensor | None = None, + ) -> torch.Tensor: + return torch.ops.vllm.rocm_aiter_gemm_a8w8_bpreshuffle( + input, weight, out_dtype, scale_a, scale_b + ) + @staticmethod def fused_moe( hidden_states: torch.Tensor, diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py index 36e4a16c0168..90cbda90adf9 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py +++ 
b/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py @@ -8,6 +8,8 @@ from vllm.logger import init_logger from vllm.model_executor.layers.quantization.kernels.scaled_mm.aiter import ( + AiterBpreshufflePerTokenFp8ScaledMMLinearKernel, + AiterCKPerTokenFp8ScaledMMLinearKernel, AiterScaledMMLinearKernel, ) from vllm.model_executor.layers.quantization.kernels.scaled_mm.cpu import ( @@ -64,6 +66,8 @@ ChannelWiseTorchScaledMMLinearKernel, ], PlatformEnum.ROCM: [ + AiterBpreshufflePerTokenFp8ScaledMMLinearKernel, + AiterCKPerTokenFp8ScaledMMLinearKernel, ROCmScaledMMLinearKernel, PerTensorTorchScaledMMLinearKernel, RowWiseTorchScaledMMLinearKernel, diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py index 4a1c76ffd9b1..28c5640d319a 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py @@ -2,17 +2,25 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections.abc import Callable + import torch +from aiter.ops.shuffle import shuffle_weight from vllm import _custom_ops as ops from vllm._aiter_ops import rocm_aiter_ops +from vllm.logger import init_logger from vllm.platforms import current_platform from .cutlass import CutlassScaledMMLinearKernel from .ScaledMMLinearKernel import ( + FP8ScaledMMLinearKernel, + FP8ScaledMMLinearLayerConfig, Int8ScaledMMLinearLayerConfig, ) +logger = init_logger(__name__) + class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel): @classmethod @@ -117,3 +125,160 @@ def apply_weights( # b to be [N, K] # CutlassScaledMMLinearKernel prepare weight `w_q` in [K, N] format return rocm_aiter_ops.gemm_a8w8(x_q, w_q.t(), x_s, w_s, bias, out_dtype) + + +class AiterBpreshufflePerTokenFp8ScaledMMLinearKernel(FP8ScaledMMLinearKernel): + def get_ouput_padding(self) -> int | None: + # PTPC kernels do not require 
padding. + return None + + @classmethod + def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: + if not current_platform.is_rocm(): + return (False, "AITER bpreshuffle is ROCm-only") + + if not rocm_aiter_ops.is_linear_enabled(): + return (False, "AITER bpreshuffle is disabled by env var") + + try: + import aiter # noqa: F401 + except Exception: + return (False, "AITER not installed") + + # Check if the configuration is PTPC + is_per_channel_weight = c.weight_quant_key.scale.group_shape.is_per_token() + is_per_token_activation = ( + c.activation_quant_key.scale.group_shape.is_per_token() + ) + is_ptpc = is_per_channel_weight and is_per_token_activation + + logger.info_once(f"AiterBpreshuffle: can_implement called. is_ptpc={is_ptpc}") + + if not is_ptpc: + return (False, "This kernel only handles Per-Token/Per-Channel (PTPC)") + + return True, None + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + logger.info_once("AiterBpreshuffle: SHUFFLING WEIGHTS NOW.") + + w_q, _, _, _ = self._get_layer_params(layer) + + N = w_q.shape[1] + K = w_q.shape[0] + + if N % 16 == 0 and K % 16 == 0: + # AITER shuffle_weight expectation [N, K] + w_q_nk = w_q.t().contiguous() + + # Execute shuffle + shuffled_w_nk = shuffle_weight(w_q_nk, layout=(16, 16)) + + del layer.weight + layer.register_buffer("weight", shuffled_w_nk) + + logger.info_once("AiterBpreshuffle: Weight shuffle COMPLETE.") + + else: + raise ValueError( + f"Weight shape (N={N}, K={K}) not divisible by 16 " + "for AITER bpreshuffle." + ) + + def apply_weights( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: torch.Tensor | None = None, + ) -> torch.Tensor: + # 1. Obtain parameters + w_q, w_s, x_s, x_s_ub = self._get_layer_params(layer) + # 2. Dynamic quantization input + qinput, qinput_scale = self.quant_fp8(x, x_s, x_s_ub) + + logger.info_once( + "AiterBpreshuffle: apply_weights... ABOUT TO CALL C++ KERNEL..."
+ ) + + output = rocm_aiter_ops.gemm_a8w8_bpreshuffle( + qinput, + w_q, # Input [N, K] shuffle weights + out_dtype=self.config.out_dtype, + scale_a=qinput_scale, + scale_b=w_s, + ) + + logger.info_once("AiterBpreshuffle: C++ KERNEL CALL SUCCEEDED.") + + if bias is not None: + output.add_(bias) + return output + + def get_scaled_mm_func(self) -> Callable[..., torch.Tensor]: + return rocm_aiter_ops.gemm_a8w8_bpreshuffle + + + class AiterCKPerTokenFp8ScaledMMLinearKernel(FP8ScaledMMLinearKernel): + """ + AITER PTPC kernel (gemm_a8w8_CK) without pre-shuffling. + """ + + def get_ouput_padding(self) -> int | None: + return None + + @classmethod + def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: + if not current_platform.is_rocm(): + return (False, "AITER CK is ROCm-only") + + if not rocm_aiter_ops.is_linear_enabled(): + return (False, "AITER CK is disabled by env var") + + try: + import aiter # noqa: F401 + except Exception: + return (False, "AITER not installed") + + is_per_channel_weight = c.weight_quant_key.scale.group_shape.is_per_token() + is_per_token_activation = ( + c.activation_quant_key.scale.group_shape.is_per_token() + ) + is_ptpc = is_per_channel_weight and is_per_token_activation + + logger.info_once(f"AiterCK: can_implement called. is_ptpc={is_ptpc}") + + if not is_ptpc: + return (False, "This kernel only handles Per-Token/Per-Channel (PTPC)") + + return True, None + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + logger.info_once( + "AITER CK: process_weights_after_loading... DOING NOTHING (pass)." + ) + pass + + def apply_weights( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: torch.Tensor | None = None, + ) -> torch.Tensor: + w_q, w_s, x_s, x_s_ub = self._get_layer_params(layer) + + qinput, qinput_scale = self.quant_fp8(x, x_s, x_s_ub) + + logger.info_once( + "AiterCK: apply_weights... " + "ABOUT TO CALL C++ KERNEL..."
+ ) + + output = rocm_aiter_ops.gemm_a8w8( + qinput, w_q.t(), qinput_scale, w_s, bias, self.config.out_dtype + ) + + logger.info_once("AiterCK: C++ KERNEL CALL SUCCEEDED.") + return output + + def get_scaled_mm_func(self) -> Callable[..., torch.Tensor]: + return rocm_aiter_ops.gemm_a8w8