diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
index bc348df84d01..69bced7c0b8e 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
@@ -26,9 +26,10 @@
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
     W4A16SPARSE24_SUPPORTED_BITS, WNA16_SUPPORTED_BITS, CompressedTensors24,
     CompressedTensorsScheme, CompressedTensorsW4A4Fp4,
-    CompressedTensorsW4A16Fp4, CompressedTensorsW4A16Sparse24,
-    CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8,
-    CompressedTensorsW8A16Fp8, CompressedTensorsWNA16)
+    CompressedTensorsW4A8Int, CompressedTensorsW4A16Fp4,
+    CompressedTensorsW4A16Sparse24, CompressedTensorsW8A8Fp8,
+    CompressedTensorsW8A8Int8, CompressedTensorsW8A16Fp8,
+    CompressedTensorsWNA16)
 from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
     find_matched_target, is_activation_quantization_format,
     should_ignore_layer)
@@ -74,7 +75,7 @@ def get_linear_method(self) -> "CompressedTensorsLinearMethod":
         return CompressedTensorsLinearMethod(self)
 
     def get_supported_act_dtypes(cls) -> list[torch.dtype]:
-        return [torch.float16, torch.bfloat16]
+        return [torch.float32, torch.float16, torch.bfloat16]
 
     @classmethod
     def get_min_capability(cls) -> int:
@@ -299,6 +300,22 @@ def _is_dynamic_token_w8a8(self, weight_quant: BaseModel,
         # Only symmetric weight quantization supported.
         return is_8_bits and is_token and weight_quant.symmetric and is_dynamic
 
+    def _is_dynamic_token_w4a8_int(self, weight_quant: BaseModel,
+                                   input_quant: BaseModel) -> bool:
+        is_weight_4_bits = weight_quant.num_bits == 4
+        is_activation_8_bits = input_quant.num_bits == 8
+        weight_strategy = (
+            weight_quant.strategy == QuantizationStrategy.GROUP.value
+            or weight_quant.strategy == QuantizationStrategy.CHANNEL.value)
+        is_token = (weight_strategy and input_quant.strategy
+                    == QuantizationStrategy.TOKEN.value)
+        is_dynamic = not weight_quant.dynamic and input_quant.dynamic
+
+        # Both symmetric and asymmetric input quantization supported.
+        # Only symmetric weight quantization supported.
+        return (is_weight_4_bits and is_activation_8_bits and is_token
+                and weight_quant.symmetric and is_dynamic)
+
     def _is_fp8_w8a8(self, weight_quant: BaseModel,
                      input_quant: BaseModel) -> bool:
         # Confirm weights and activations quantized.
@@ -374,7 +391,6 @@ def _is_wNa16_group_channel(self, weight_quant: BaseModel,
     def _get_scheme_from_parts(
             self, weight_quant: BaseModel,
             input_quant: BaseModel) -> "CompressedTensorsScheme":
-        # Detect If Mixed Precision
         if self._is_fp4a16_nvfp4(weight_quant, input_quant):
             return CompressedTensorsW4A16Fp4()
 
@@ -443,6 +459,16 @@ def _get_scheme_from_parts(
                 is_static_input_scheme=False,
                 input_symmetric=input_quant.symmetric)
 
+        if self._is_dynamic_token_w4a8_int(weight_quant, input_quant):
+            is_static_input_scheme = (input_quant
+                                      and not input_quant.dynamic)
+            return CompressedTensorsW4A8Int(
+                num_bits=weight_quant.num_bits,
+                strategy=weight_quant.strategy,
+                group_size=weight_quant.group_size,
+                is_static_input_scheme=is_static_input_scheme,
+                input_symmetric=input_quant.symmetric)
+
         raise NotImplementedError(
             "No compressed-tensors compatible scheme was found.")
 
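The _is_dynamic_token_w4a8_int check added above matches 4-bit, symmetric, statically quantized weights (group- or channel-wise) paired with dynamic per-token 8-bit activations. As a hedged illustration only (the field names mirror the QuantizationArgs attributes the check reads; the exact checkpoint JSON layout may differ), a matching config group looks roughly like:

    w4a8_config_group = {
        "weights": {
            "num_bits": 4,
            "type": "int",
            "symmetric": True,
            "strategy": "group",  # "channel" is also accepted
            "group_size": 32,
            "dynamic": False,
        },
        "input_activations": {
            "num_bits": 8,
            "type": "int",
            "symmetric": True,  # asymmetric activations are also accepted
            "strategy": "token",
            "dynamic": True,
        },
    }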
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py
index 6e4e75df7604..734fa603ba7b 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py
@@ -3,6 +3,7 @@
 
 from .compressed_tensors_scheme import CompressedTensorsScheme
 from .compressed_tensors_w4a4_nvfp4 import CompressedTensorsW4A4Fp4
+from .compressed_tensors_w4a8_int import CompressedTensorsW4A8Int
 from .compressed_tensors_w4a16_24 import (W4A16SPARSE24_SUPPORTED_BITS,
                                           CompressedTensorsW4A16Sparse24)
 from .compressed_tensors_w4a16_nvfp4 import CompressedTensorsW4A16Fp4
@@ -20,5 +21,5 @@
     "CompressedTensorsW8A8Int8", "CompressedTensorsW8A8Fp8",
     "WNA16_SUPPORTED_BITS", "W4A16SPARSE24_SUPPORTED_BITS",
     "CompressedTensors24", "CompressedTensorsW4A16Fp4",
-    "CompressedTensorsW4A4Fp4"
+    "CompressedTensorsW4A4Fp4", "CompressedTensorsW4A8Int"
 ]
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a8_int.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a8_int.py
new file mode 100644
index 000000000000..f1fca85508a6
--- /dev/null
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a8_int.py
@@ -0,0 +1,135 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from typing import Callable, Optional
+
+import torch
+
+from vllm.logger import init_logger
+from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
+    CompressedTensorsScheme)
+from vllm.model_executor.layers.quantization.kernels.mixed_precision import (
+    MPLinearLayerConfig, choose_mp_linear_kernel)
+from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
+                                           GroupQuantScaleParameter,
+                                           ModelWeightParameter)
+from vllm.scalar_type import scalar_types
+
+logger = init_logger(__name__)
+
+__all__ = ["CompressedTensorsW4A8Int"]
+W4A8_SUPPORTED_TYPES_MAP = {
+    4: scalar_types.int4,
+}
+W4A8_SUPPORTED_BITS = list(W4A8_SUPPORTED_TYPES_MAP.keys())
+
+
+class CompressedTensorsW4A8Int(CompressedTensorsScheme):
+    _kernel_backends_being_used: set[str] = set()
+
+    def __init__(self,
+                 strategy: str,
+                 num_bits: int,
+                 group_size: Optional[int] = None,
+                 is_static_input_scheme: bool = False,
+                 input_symmetric: bool = True):
+        self.strategy = strategy
+        self.group_size = -1 if group_size is None else group_size
+        self.is_static_input_scheme = is_static_input_scheme
+        self.input_symmetric = input_symmetric
+
+        if num_bits not in W4A8_SUPPORTED_TYPES_MAP:
+            raise ValueError(
+                f"Unsupported num_bits = {num_bits}. "
+                f"Supported num_bits = {W4A8_SUPPORTED_TYPES_MAP.keys()}")
+        self.quant_type = W4A8_SUPPORTED_TYPES_MAP[num_bits]
+
+    @classmethod
+    def get_min_capability(cls) -> int:
+        return 1
+
+    def create_weights(self, layer: torch.nn.Module, output_size: int,
+                       input_size: int, output_partition_sizes: list[int],
+                       input_size_per_partition: int,
+                       params_dtype: torch.dtype, weight_loader: Callable,
+                       **kwargs):
+        output_size_per_partition = sum(output_partition_sizes)
+        row_parallel = (input_size != input_size_per_partition)
+
+        # Compute effective group_size
+        if self.group_size == -1:
+            effective_group_size = (input_size_per_partition
+                                    if row_parallel else input_size)
+        else:
+            effective_group_size = self.group_size
+
+        # Ensure group_size divides input_size_per_partition
+        assert input_size_per_partition % effective_group_size == 0, (
+            f"input_size_per_partition {input_size_per_partition}"
+            f" not divisible by group_size {effective_group_size}")
+
+        # Determine scale partitioning
+        is_channelwise = (self.group_size == -1)
+        repeat_scales = (is_channelwise and row_parallel)
+        partition_scales = not repeat_scales
+
+        mp_linear_kernel_config = MPLinearLayerConfig(
+            full_weight_shape=(input_size, output_size),
+            partition_weight_shape=(input_size_per_partition,
+                                    output_size_per_partition),
+            weight_type=self.quant_type,
+            act_type=params_dtype,
+            group_size=effective_group_size,
+            zero_points=False,
+            has_g_idx=False,
+        )
+
+        kernel_type = choose_mp_linear_kernel(mp_linear_kernel_config)
+        if kernel_type.__name__ not in self._kernel_backends_being_used:
+            logger.info("Using %s for CompressedTensorsW4A8Int",
+                        kernel_type.__name__)
+            self._kernel_backends_being_used.add(kernel_type.__name__)
+
+        scales_and_zp_size = input_size_per_partition // effective_group_size
+
+        weight = ModelWeightParameter(data=torch.empty(
+            output_size_per_partition,
+            input_size_per_partition,
+            dtype=torch.int8),
+                                      input_dim=1,
+                                      output_dim=0,
+                                      weight_loader=weight_loader)
+        layer.register_parameter("weight", weight)
+
+        weight_scale_args = {
+            "weight_loader":
+            weight_loader,
+            "data":
+            torch.empty(output_size_per_partition,
+                        scales_and_zp_size,
+                        dtype=params_dtype)
+        }
+
+        if partition_scales:
+            weight_scale = GroupQuantScaleParameter(output_dim=0,
+                                                    input_dim=1,
+                                                    **weight_scale_args)
+        else:
+            weight_scale = ChannelQuantScaleParameter(output_dim=0,
+                                                      **weight_scale_args)
+
+        layer.register_parameter("weight_packed", weight)
+        layer.register_parameter("weight_scale", weight_scale)
+
+        self.kernel = kernel_type(mp_linear_kernel_config,
+                                  w_q_param_name="weight_packed",
+                                  w_s_param_name="weight_scale",
+                                  w_zp_param_name=None,
+                                  w_gidx_param_name=None)
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        self.kernel.process_weights_after_loading(layer)
+
+    def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
+                      bias: Optional[torch.Tensor]) -> torch.Tensor:
+        return self.kernel.apply_weights(layer, x, bias)
diff --git a/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py
index 21e5ae793c3f..a5084f6ee92c 100644
--- a/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py
+++ b/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py
@@ -10,6 +10,8 @@
     BitBLASLinearKernel)
 from vllm.model_executor.layers.quantization.kernels.mixed_precision.conch import (  # noqa: E501
     ConchLinearKernel)
+from vllm.model_executor.layers.quantization.kernels.mixed_precision.dynamic_4bit import (  # noqa: E501
+    Dynamic4bitLinearKernel)
 from vllm.model_executor.layers.quantization.kernels.mixed_precision.exllama import (  # noqa: E501
     ExllamaLinearKernel)
 from vllm.model_executor.layers.quantization.kernels.mixed_precision.machete import (  # noqa: E501
@@ -25,6 +27,7 @@
     MacheteLinearKernel,
     AllSparkLinearKernel,
     MarlinLinearKernel,
+    Dynamic4bitLinearKernel,
     BitBLASLinearKernel,
     ConchLinearKernel,
     ExllamaLinearKernel,
@@ -56,7 +59,8 @@ def choose_mp_linear_kernel(
         if current_platform is None:
             raise ValueError("Cannot determine compute capability")
         _cc = current_platform.get_device_capability()
-        compute_capability = _cc[0] * 10 + _cc[1]
+        if _cc is not None:
+            compute_capability = _cc[0] * 10 + _cc[1]
 
     failure_reasons = []
     for kernel in _POSSIBLE_KERNELS:
@@ -64,12 +68,12 @@ def choose_mp_linear_kernel(
             failure_reasons.append(
                 f' {kernel.__name__} disabled by environment variable')
             continue
-
-        if kernel.get_min_capability() > compute_capability:
+        if (compute_capability is not None
+                and kernel.get_min_capability() > compute_capability):
             failure_reasons.append(
                 f"{kernel.__name__} requires capability "
-                f"{kernel.get_min_capability()}, current compute capability "
-                f"is {compute_capability}")
+                f"{kernel.get_min_capability()}, current compute "
+                f"capability is {compute_capability}")
             continue
 
         can_implement, failure_reason = kernel.can_implement(config)
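With Dynamic4bitLinearKernel registered in _POSSIBLE_KERNELS above, an int4 CPU configuration should now resolve to it through choose_mp_linear_kernel. A minimal sketch, assuming an Arm CPU build of vLLM where the kernel's can_implement checks pass (the shapes and group size below are arbitrary illustrations, not values from this patch):

    import torch

    from vllm.model_executor.layers.quantization.kernels.mixed_precision import (
        MPLinearLayerConfig, choose_mp_linear_kernel)
    from vllm.scalar_type import scalar_types

    cfg = MPLinearLayerConfig(
        full_weight_shape=(4096, 4096),       # (input_size, output_size)
        partition_weight_shape=(4096, 4096),  # no tensor-parallel split here
        weight_type=scalar_types.int4,
        act_type=torch.float32,  # the Arm path requires float32 activations
        group_size=32,
        zero_points=False,
        has_g_idx=False,
    )
    kernel_cls = choose_mp_linear_kernel(cfg)
    print(kernel_cls.__name__)  # expected: Dynamic4bitLinearKernel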
diff --git a/vllm/model_executor/layers/quantization/kernels/mixed_precision/dynamic_4bit.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/dynamic_4bit.py
new file mode 100644
index 000000000000..7bd326f47f9e
--- /dev/null
+++ b/vllm/model_executor/layers/quantization/kernels/mixed_precision/dynamic_4bit.py
@@ -0,0 +1,92 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from typing import Optional
+
+import torch
+
+from vllm.model_executor.layers.quantization.utils import replace_parameter
+from vllm.platforms import CpuArchEnum, current_platform
+from vllm.scalar_type import scalar_types
+
+from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig
+
+
+class Dynamic4bitLinearKernel(MPLinearKernel):
+    SUPPORTED_QUANT_TYPES = [scalar_types.int4]
+
+    @classmethod
+    def get_min_capability(cls) -> int:
+        return 1
+
+    @classmethod
+    def can_implement(cls,
+                      c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]:
+        if not current_platform.is_cpu():
+            return False, "Only CPU is supported"
+        if c.weight_type not in cls.SUPPORTED_QUANT_TYPES:
+            return False, f"Unsupported quant type {c.weight_type}"
+        if current_platform.get_cpu_architecture(
+        ) == CpuArchEnum.ARM and c.act_type not in [
+                torch.float32,
+        ]:
+            return False, "Dynamic4bitLinearKernel on Arm requires"\
+                " Float32 activations"
+        if c.full_weight_shape[0] % c.group_size != 0:
+            return False, f"Group size ({c.group_size}) does not evenly divide"\
+                " the number of input features "\
+                f"({c.full_weight_shape[0]})"
+        if current_platform.get_cpu_architecture() == CpuArchEnum.ARM:
+            try:
+                # Attempt to retrieve the operation
+                _ = torch.ops.aten._dyn_quant_matmul_4bit
+            except AttributeError:
+                return False, f"PyTorch {torch.__version__} does not support"\
+                    " _dyn_quant_matmul_4bit. Install a newer version"
+        return True, None
+
+    def process_weights_after_loading(self, layer: torch.nn.Module):
+        c = self.config
+        packed_weight = getattr(layer, self.w_q_name)
+        packed_weight = packed_weight.add(8)
+        uint8_packed = (packed_weight[::, 1::2] << 4
+                        | packed_weight[::, ::2]).to(torch.uint8)
+
+        scales = getattr(layer, self.w_s_name)
+        block_size = c.group_size
+
+        # Handle scaling factors for partitioned weights
+        if block_size == c.partition_weight_shape[0]:
+            scales = scales.to(
+                torch.float32
+            )  # Float32 & Bfloat16 variants require float32 scales
+            scales = scales.view(-1, 1)  # Channel-wise scales
+            if layer.bias is not None:
+                layer.bias = layer.bias.to(
+                    torch.float32
+                )  # Float32 & Bfloat16 variants require float32 bias
+        else:
+            # KleidiAI kernel requires bfloat16 scales with groupwise scheme
+            scales = scales.to(torch.bfloat16)
+
+        # Repack weights as per kernel requirement
+        w = torch.ops.aten._dyn_quant_pack_4bit_weight(
+            uint8_packed, scales, layer.bias, block_size,
+            c.partition_weight_shape[0], c.partition_weight_shape[1])
+        replace_parameter(layer, self.w_q_name,
+                          torch.nn.Parameter(w, requires_grad=False))
+        setattr(layer, self.w_s_name, None)
+
+    def apply_weights(self,
+                      layer: torch.nn.Module,
+                      x: torch.Tensor,
+                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+        c = self.config
+        x_2d = x.reshape(-1, x.shape[-1])
+        out_shape = x.shape[:-1] + (c.partition_weight_shape[1], )
+
+        w_q = getattr(layer, self.w_q_name)
+        output = torch.ops.aten._dyn_quant_matmul_4bit(
+            x_2d, w_q, c.group_size, c.partition_weight_shape[0],
+            c.partition_weight_shape[1])
+        return output.reshape(out_shape)
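For reference, process_weights_after_loading above shifts the loaded int4 values (stored one per int8 element, in [-8, 7]) into unsigned nibbles and packs two adjacent input columns per byte before calling _dyn_quant_pack_4bit_weight. A small self-contained sketch of just that nibble layout, in plain PyTorch with no aten packing op, to make the convention explicit:

    import torch

    # int4 values stored one per int8 element, as the checkpoint loader leaves them
    w = torch.randint(-8, 8, (4, 8), dtype=torch.int8)

    # Shift [-8, 7] -> [0, 15], then pack two adjacent input columns per byte:
    # even columns land in the low nibble, odd columns in the high nibble.
    u = (w + 8).to(torch.uint8)
    packed = (u[:, 1::2] << 4) | u[:, ::2]

    # Round trip: both nibbles recover the original signed int4 values.
    even = (packed & 0x0F).to(torch.int8) - 8
    odd = (packed >> 4).to(torch.int8) - 8
    assert torch.equal(even, w[:, ::2])
    assert torch.equal(odd, w[:, 1::2])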