From 4d27a2c4680a40fc5509a8ec006b9e2e7c40b80f Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Tue, 30 Apr 2024 13:10:42 -0400 Subject: [PATCH 1/6] Initial `CompressedTensors` config + Activation Quantization support for static W8A8 per tensor (#195) - Depending on how we end up parsing `ignore` and `targets` (layer_name vs layer_type) we may not need layer_name to be added to the linear_method. Will experiment using a compressed-tensors function in a follow-up PR - Initial implementation for Compressed Config support + Activation Quantization for static per tensor w8a8 - Includes fused kernels added by @varun-sundar-rabindranath ```python from vllm import LLM, SamplingParams import torch prompts = [ "Hello, my name is", "The capital of France is", "The US president is", "The future of AI is" ] sampling_params = SamplingParams(temperature=0.80, top_p=0.95) llm = LLM(model="nm-testing/tinyllama-one-shot-static-quant-test", enforce_eager=True, dtype=torch.float32, quantization="sparseml") outputs = llm.generate(prompts, sampling_params) for output in outputs: prompt = output.prompt generated_text = output.outputs[0].text print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") ``` - Verification of the different inputs expected for `targets` and `ignore` --> use functions to parse the layer names which can be shared by sparseml and vllm; would live in compressed tensors (https://github.com/neuralmagic/compressed-tensors/blob/67005d76107d4659787f1efd53fe7e6b1d192818/src/compressed_tensors/quantization/lifecycle/apply.py#L86) - Updates to further optimize fake quant --------- Co-authored-by: Varun Sundar Rabindranath Co-authored-by: Varun Sundar Rabindranath --- CMakeLists.txt | 1 + csrc/ops.h | 5 + csrc/pybind.cpp | 9 + .../compressed_tensors/int8_quant_kernels.cu | 50 ++++++ requirements-cuda.txt | 1 + vllm/model_executor/layers/linear.py | 169 ++++++++++++++---- .../layers/quantization/__init__.py | 5 +- .../layers/quantization/aqlm.py | 1 + 
.../model_executor/layers/quantization/awq.py | 1 + .../compressed_tensors/__init__.py | 0 .../compressed_tensors/compressed_tensors.py | 159 ++++++++++++++++ .../compressed_tensors/cutlass_gemm.py | 91 ++++++++++ .../compressed_tensors/schemes/__init__.py | 3 + .../schemes/compressed_tensors_scheme.py | 32 ++++ .../schemes/compressed_tensors_unquantized.py | 36 ++++ .../compressed_tensors_w8a8_statictensor.py | 137 ++++++++++++++ .../model_executor/layers/quantization/fp8.py | 1 + .../layers/quantization/gptq.py | 1 + .../layers/quantization/marlin.py | 1 + .../layers/quantization/squeezellm.py | 1 + vllm/model_executor/models/llama.py | 39 ++-- vllm/worker/model_runner.py | 2 +- 22 files changed, 691 insertions(+), 54 deletions(-) create mode 100644 csrc/quantization/compressed_tensors/int8_quant_kernels.cu create mode 100644 vllm/model_executor/layers/quantization/compressed_tensors/__init__.py create mode 100644 vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py create mode 100644 vllm/model_executor/layers/quantization/compressed_tensors/cutlass_gemm.py create mode 100644 vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py create mode 100644 vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py create mode 100644 vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_unquantized.py create mode 100644 vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py diff --git a/CMakeLists.txt b/CMakeLists.txt index e9262b57d086..261e57274f8c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -167,6 +167,7 @@ set(VLLM_EXT_SRC "csrc/layernorm_kernels.cu" "csrc/quantization/squeezellm/quant_cuda_kernel.cu" "csrc/quantization/gptq/q_gemm.cu" + "csrc/quantization/compressed_tensors/int8_quant_kernels.cu" "csrc/quantization/fp8/fp8_cuda_kernels.cu" "csrc/cuda_utils_kernels.cu" 
"csrc/moe_align_block_size_kernels.cu" diff --git a/csrc/ops.h b/csrc/ops.h index 03bb1e24dc68..823dabf90c30 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -156,6 +156,11 @@ void dynamic_scaled_fp8_quant( torch::Tensor& input, torch::Tensor& scale); +void quant_per_tensor( + torch::Tensor& out, + torch::Tensor& input, + float scale); + void moe_align_block_size( torch::Tensor topk_ids, int num_experts, diff --git a/csrc/pybind.cpp b/csrc/pybind.cpp index 2250c7f69f0a..13514065456c 100644 --- a/csrc/pybind.cpp +++ b/csrc/pybind.cpp @@ -80,6 +80,15 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { &moe_align_block_size, "Aligning the number of tokens to be processed by each expert such that it is divisible by the block size."); + ops.def( + "quant_per_tensor", + py::overload_cast< + torch::Tensor&, + torch::Tensor&, + float>(&quant_per_tensor), + "Per-tensor Quantization"); + + // Cache ops pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops"); cache_ops.def( diff --git a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu new file mode 100644 index 000000000000..e1af55dc225a --- /dev/null +++ b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu @@ -0,0 +1,50 @@ +#include +#include +#include + +#include "../../dispatch_utils.h" + +static inline __device__ int8_t float_to_int8_rn(float x) +{ + uint32_t dst; + asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=r"(dst) : "f"(x)); + return reinterpret_cast(dst); +} + +namespace vllm { + +template +__global__ void quant_kernel( + const scalar_t* __restrict__ input, + int8_t* __restrict__ out, + scale_type scale, + const int hidden_size) { + const int tid = threadIdx.x; + const int token_idx = blockIdx.x; + + for (int i = tid; i < hidden_size; i += blockDim.x) { + out[token_idx * hidden_size + i] = + float_to_int8_rn(((float)input[token_idx * hidden_size + i]) / scale); + } +} +} // namespace vllm + +void quant_per_tensor( + 
torch::Tensor& out, // [..., hidden_size] + torch::Tensor& input, // [..., hidden_size] + float scale) { + assert(input.is_contiguous()); + assert(out.is_contiguous()); + int hidden_size = input.size(-1); + int num_tokens = input.numel() / hidden_size; + dim3 grid(num_tokens); + dim3 block(std::min(hidden_size, 1024)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "quant_kernel", [&] { + vllm::quant_kernel<<>>( + input.data_ptr(), + out.data_ptr(), + scale, + hidden_size); + }); +} diff --git a/requirements-cuda.txt b/requirements-cuda.txt index 1bddae4c6f40..f4c04afd55c7 100644 --- a/requirements-cuda.txt +++ b/requirements-cuda.txt @@ -7,3 +7,4 @@ nvidia-ml-py # for pynvml package vllm-nccl-cu12>=2.18,<2.19 # for downloading nccl library torch == 2.2.1 xformers == 0.0.25 # Requires PyTorch 2.2.1 +nvidia-cutlass == 3.5.0 diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 4d43ed4c5f14..5469898972e4 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -1,5 +1,5 @@ -from abc import abstractmethod -from typing import List, Optional +from abc import ABC, abstractmethod +from typing import Any, Dict, List, Optional import torch import torch.nn.functional as F @@ -30,11 +30,13 @@ class LinearMethodBase(QuantizeMethodBase): """Base class for different (maybe quantized) linear methods.""" @abstractmethod - def create_weights(self, layer: torch.nn.Module, + def create_weights(self, layer: torch.nn.Module, input_size_per_partition: int, output_partition_sizes: List[int], input_size: int, output_size: int, params_dtype: torch.dtype, - **extra_weight_attrs): + layer_name: Optional[str] = None, + **extra_weight_attrs) -> Dict[str, Any]: + """Create weights for a linear layer. The weights will be set as attributes of the layer. 
@@ -47,6 +49,7 @@ def create_weights(self, layer: torch.nn.Module, input_size: Size of the input dim of the weight across all ranks. output_size: Size of the output dim of the weight across all ranks. params_dtype: Datatype of the parameters. + layer_name: name of the layer in the state dict. """ raise NotImplementedError @@ -56,7 +59,6 @@ def apply(self, x: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor: """Apply the weights in layer to the input tensor. - Expects create_weights to have been called before on the layer.""" raise NotImplementedError @@ -76,9 +78,9 @@ def create_weights(self, layer: torch.nn.Module, input_size_per_partition: int, output_partition_sizes: List[int], input_size: int, output_size: int, params_dtype: torch.dtype, - **extra_weight_attrs): - output_size_per_partition = sum(output_partition_sizes) - weight = Parameter(torch.empty(output_size_per_partition, + layer_name: Optional[str] = None, + **extra_weight_attrs) -> Dict[str, Any]: + weight = Parameter(torch.empty(sum(output_partition_sizes), input_size_per_partition, dtype=params_dtype), requires_grad=False) @@ -108,6 +110,7 @@ class LinearBase(torch.nn.Module): skip_bias_add: If true, skip adding bias but instead return it. params_dtype: Data type for the parameters. quant_config: Quantization configure. + layer_name: name of the layer in the state dict. 
""" def __init__( @@ -117,10 +120,12 @@ def __init__( skip_bias_add: bool = False, params_dtype: Optional[torch.dtype] = None, quant_config: Optional[QuantizationConfig] = None, + layer_name: Optional[str] = None ): super().__init__() # Keep input parameters + self.layer_name = layer_name self.input_size = input_size self.output_size = output_size self.skip_bias_add = skip_bias_add @@ -157,15 +162,16 @@ def __init__( skip_bias_add: bool = False, params_dtype: Optional[torch.dtype] = None, quant_config: Optional[QuantizationConfig] = None, + layer_name: Optional[str] = None ): super().__init__(input_size, output_size, skip_bias_add, params_dtype, - quant_config) + quant_config, layer_name) # All the linear layer supports quant method. assert self.quant_method is not None self.quant_method.create_weights(self, self.input_size, [self.output_size], self.input_size, - self.output_size, self.params_dtype) + self.output_size, self.params_dtype, layer_name=self.layer_name) if bias: self.bias = Parameter( @@ -202,6 +208,7 @@ class ColumnParallelLinear(LinearBase): quant_config: Quantization configure. output_sizes: list of output sizes packed into one output, like for QKV the list would be size 3. + layer_name: name of the layer in the state dict. """ def __init__( @@ -214,6 +221,7 @@ def __init__( params_dtype: Optional[torch.dtype] = None, quant_config: Optional[QuantizationConfig] = None, output_sizes: Optional[List[int]] = None, + layer_name: Optional[str] = None ): super().__init__(input_size, output_size, skip_bias_add, params_dtype, quant_config) @@ -222,18 +230,27 @@ def __init__( # Divide the weight matrix along the last dimension. 
tp_size = get_tensor_model_parallel_world_size() - self.output_size_per_partition = divide(output_size, tp_size) + assert self.quant_method is not None + self.output_size_per_partition = divide(self.output_size, tp_size) + self.output_partition_sizes = [self.output_size_per_partition] + # If QKV or MergedColumn, use output size of each partition. + if hasattr(self, "output_sizes"): + self.output_partition_sizes = [ + divide(output_size, tp_size) + for output_size in self.output_sizes + ] + if output_sizes is None: output_sizes = [output_size] - # All the linear layer supports quant method. - assert self.quant_method is not None - self.quant_method.create_weights(self, - self.input_size, - [x // tp_size for x in output_sizes], - self.input_size, - self.output_size, - self.params_dtype, - weight_loader=self.weight_loader) + self.quant_method.create_weights( + layer=self, + layer_name=self.layer_name, + input_size_per_partition=self.input_size, + output_partition_sizes=self.output_partition_sizes, + input_size=self.input_size, + output_size=self.output_size, + params_dtype=self.params_dtype, + weight_loader=self.weight_loader) if bias: self.bias = Parameter( torch.empty(self.output_size_per_partition, @@ -302,13 +319,19 @@ def __init__( skip_bias_add: bool = False, params_dtype: Optional[torch.dtype] = None, quant_config: Optional[QuantizationConfig] = None, + layer_name: Optional[str] = None ): self.output_sizes = output_sizes tp_size = get_tensor_model_parallel_world_size() assert all(output_size % tp_size == 0 for output_size in output_sizes) - super().__init__(input_size, sum(output_sizes), bias, gather_output, - skip_bias_add, params_dtype, quant_config, - self.output_sizes) + super().__init__(layer_name=layer_name, + input_size=input_size, + output_size=sum(output_sizes), + bias=bias, + gather_output=gather_output, + skip_bias_add=skip_bias_add, + params_dtype=params_dtype, + quant_config=quant_config) def weight_loader(self, param: Parameter, @@ -318,6 +341,19 
@@ def weight_loader(self, param_data = param.data output_dim = getattr(param, "output_dim", None) is_metadata = getattr(param, "is_metadata", False) + param_shard_splitter = getattr(param, "shard_splitter", None) + + if output_dim is not None and param_shard_splitter is not None: + raise NotImplementedError( + "We do not currently support output_dim != None and " + "shard_splitter != None for a parameter. Please open an issue." + ) + if loaded_shard_id is None and param_shard_splitter is not None: + raise NotImplementedError( + "We do not currently support loaded_shard_id == None and " + "shard_splitter != None for a parameter. Please open an issue." + ) + if loaded_shard_id is None: # Loaded weight is already packed. if output_dim is None: @@ -375,6 +411,12 @@ def weight_loader(self, shard_size = loaded_weight.shape[0] shard_offset = loaded_shard_id * shard_size param_data = param_data.narrow(0, shard_offset, shard_size) + # If a param_shard_splitter is defined by the LinearMethod, use it. + elif param_shard_splitter is not None: + logical_widths = getattr(param, "logical_widths") + param_data, loaded_weight = param_shard_splitter( + param_data, loaded_weight, loaded_shard_id, logical_widths) + else: ignore_warning = getattr(param, "ignore_warning", False) if not ignore_warning: @@ -382,6 +424,13 @@ def weight_loader(self, "Loading a weight without `output_dim` attribute in " "MergedColumnParallelLinear, assume the weight is " "the same for all partitions.") + + if len(param_data.shape) == 0: + param_data = param_data.reshape(1) + + if len(loaded_weight.shape) == 0: + loaded_weight = loaded_weight.reshape(1) + assert param_data.shape == loaded_weight.shape param_data.copy_(loaded_weight) @@ -408,6 +457,7 @@ class QKVParallelLinear(ColumnParallelLinear): skip adding bias but instead return it. params_dtype: Data type for the parameters. quant_config: Quantization configure. + layer_name: name of the layer in the state dict. 
""" def __init__( @@ -420,6 +470,7 @@ def __init__( skip_bias_add: bool = False, params_dtype: Optional[torch.dtype] = None, quant_config: Optional[QuantizationConfig] = None, + layer_name: Optional[str] = None ): self.hidden_size = hidden_size self.head_size = head_size @@ -440,14 +491,20 @@ def __init__( input_size = self.hidden_size output_size = (self.num_heads + 2 * self.num_kv_heads) * tp_size * self.head_size - output_sizes = [ - self.num_heads * tp_size * self.head_size, - self.num_kv_heads * tp_size * self.head_size, - self.num_kv_heads * tp_size * self.head_size + self.output_sizes = [ + self.num_heads * self.head_size * tp_size, # q_proj + self.num_kv_heads * self.head_size * tp_size, # k_proj + self.num_kv_heads * self.head_size * tp_size, # v_proj ] - super().__init__(input_size, output_size, bias, False, skip_bias_add, - params_dtype, quant_config, output_sizes) + super().__init__(layer_name=layer_name, + input_size=input_size, + output_size=output_size, + bias=bias, + gather_output=False, + skip_bias_add=skip_bias_add, + params_dtype=params_dtype, + quant_config=quant_config) def weight_loader(self, param: Parameter, @@ -456,6 +513,18 @@ def weight_loader(self, param_data = param.data output_dim = getattr(param, "output_dim", None) is_metadata = getattr(param, "is_metadata", False) + param_shard_splitter = getattr(param, "shard_splitter", None) + + if output_dim is not None and param_shard_splitter is not None: + raise NotImplementedError( + "We do not currently support output_dim != None and " + "shard_splitter != None for a parameter. Please open an issue." + ) + if loaded_shard_id is None and param_shard_splitter is not None: + raise NotImplementedError( + "We do not currently support loaded_shard_id == None and " + "shard_splitter != None for a parameter. Please open an issue." + ) if loaded_shard_id is None: # Loaded weight is already packed. 
@@ -491,11 +560,14 @@ def weight_loader(self, tp_rank = get_tensor_model_parallel_rank() assert loaded_shard_id in ["q", "k", "v"] + + # If output dim is defined, use the default loading process. if output_dim is not None: if loaded_shard_id == "q": shard_offset = 0 shard_size = self.num_heads * self.head_size elif loaded_shard_id == "k": + # shard_offset = self.num_heads * self.head_size shard_size = self.num_kv_heads * self.head_size elif loaded_shard_id == "v": @@ -529,6 +601,12 @@ def weight_loader(self, shard_index = ["q", "k", "v"].index(loaded_shard_id) param_data = param_data.narrow(0, shard_index * shard_size, shard_size) + # If a param_shard_splitter is defined by the LinearMethod, use it. + elif param_shard_splitter is not None: + logical_widths = getattr(param, "logical_widths") + param_data, loaded_weight = param_shard_splitter( + param_data, loaded_weight, loaded_shard_id, logical_widths) + else: ignore_warning = getattr(param, "ignore_warning", False) if not ignore_warning: @@ -536,6 +614,13 @@ def weight_loader(self, "Loading a weight without `output_dim` attribute in " "QKVParallelLinear, assume the weight is the same " "for all partitions.") + + if len(param_data.shape) == 0: + param_data = param_data.reshape(1) + + if len(loaded_weight.shape) == 0: + loaded_weight = loaded_weight.reshape(1) + assert param_data.shape == loaded_weight.shape param_data.copy_(loaded_weight) @@ -564,6 +649,7 @@ class RowParallelLinear(LinearBase): We skip adding bias but instead return it. params_dtype: Data type for the parameters. quant_config: Quantization configure. + layer_name: name of the layer in the state dict. 
""" def __init__( @@ -576,9 +662,10 @@ def __init__( params_dtype: Optional[torch.dtype] = None, reduce_results: bool = True, quant_config: Optional[QuantizationConfig] = None, + layer_name: Optional[str] = None ): super().__init__(input_size, output_size, skip_bias_add, params_dtype, - quant_config) + quant_config, layer_name) self.input_is_parallel = input_is_parallel self.reduce_results = reduce_results @@ -586,16 +673,16 @@ def __init__( # Divide the weight matrix along the last dimension. self.tp_size = get_tensor_model_parallel_world_size() self.input_size_per_partition = divide(input_size, self.tp_size) - # All the linear layer supports quant method. assert self.quant_method is not None - self.quant_method.create_weights(self, - self.input_size_per_partition, - [self.output_size], - self.input_size, - self.output_size, - self.params_dtype, - weight_loader=self.weight_loader) - + self.quant_method.create_weights( + layer=self, + layer_name=self.layer_name, + input_size_per_partition=self.input_size_per_partition, + output_partition_sizes=[self.output_size], + input_size=self.input_size, + output_size=self.output_size, + params_dtype=self.params_dtype, + weight_loader=self.weight_loader) if not reduce_results and (bias and not skip_bias_add): raise ValueError("When not reduce the results, adding bias to the " "results can lead to incorrect results") @@ -619,6 +706,10 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): start_idx = tp_rank * shard_size loaded_weight = loaded_weight.narrow(input_dim, start_idx, shard_size) + + if len(loaded_weight.shape) == 0: + loaded_weight = loaded_weight.reshape(1) + assert param_data.shape == loaded_weight.shape param_data.copy_(loaded_weight) diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py index 70e0a7cfe3e3..06fb0c905623 100644 --- a/vllm/model_executor/layers/quantization/__init__.py +++ 
b/vllm/model_executor/layers/quantization/__init__.py @@ -4,7 +4,9 @@ from vllm.model_executor.layers.quantization.awq import AWQConfig from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) -from vllm.model_executor.layers.quantization.fp8 import Fp8Config +from vllm.model_executor.layers.quantization.fp8 import FP8Config +from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501 + CompressedTensorsConfig) from vllm.model_executor.layers.quantization.gptq import GPTQConfig from vllm.model_executor.layers.quantization.marlin import MarlinConfig from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig @@ -16,6 +18,7 @@ "gptq": GPTQConfig, "squeezellm": SqueezeLLMConfig, "marlin": MarlinConfig, + "sparseml": CompressedTensorsConfig } diff --git a/vllm/model_executor/layers/quantization/aqlm.py b/vllm/model_executor/layers/quantization/aqlm.py index 83e24fadc140..6edb3c3e9c63 100644 --- a/vllm/model_executor/layers/quantization/aqlm.py +++ b/vllm/model_executor/layers/quantization/aqlm.py @@ -231,6 +231,7 @@ def create_weights(self, layer: torch.nn.Module, input_size_per_partition: int, output_partition_sizes: List[int], input_size: int, output_size: int, params_dtype: torch.dtype, + layer_name: Optional[str] = None, **extra_weight_attrs): del output_size # Unused. del input_size # Unused. 
diff --git a/vllm/model_executor/layers/quantization/awq.py b/vllm/model_executor/layers/quantization/awq.py index f4fc7ce020e9..00b4a4714be1 100644 --- a/vllm/model_executor/layers/quantization/awq.py +++ b/vllm/model_executor/layers/quantization/awq.py @@ -86,6 +86,7 @@ def create_weights(self, layer: torch.nn.Module, input_size_per_partition: int, output_partition_sizes: List[int], input_size: int, output_size: int, params_dtype: torch.dtype, + layer_name: Optional[str] = None, **extra_weight_attrs): if input_size_per_partition % self.quant_config.group_size != 0: raise ValueError( diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/__init__.py b/vllm/model_executor/layers/quantization/compressed_tensors/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py new file mode 100644 index 000000000000..a61bec6e0323 --- /dev/null +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -0,0 +1,159 @@ +from typing import Any, Dict, List, Optional + +import torch + +from vllm.model_executor.layers.linear import LinearMethodBase +from vllm.model_executor.layers.quantization.base_config import ( # noqa: E501 + QuantizationConfig) + +from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( + CompressedTensorsW8A8StaticTensor, CompressedTensorsUnquantized, + CompressedTensorsScheme) +from vllm.model_executor.utils import set_weight_attrs + + +class CompressedTensorsConfig(QuantizationConfig): + + def __init__(self, layer_quant_details: Dict[str, Any], ignore: List[str], + fake_quant: bool): + self.fake_quant = fake_quant + self.ignore = ignore + self.layer_quant_details = layer_quant_details + + def get_linear_method(self) -> "CompressedTensorsLinearMethod": + return CompressedTensorsLinearMethod(self) + + def 
get_scaled_act_names(self) -> List[str]: + return [] + + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.float32, torch.int8] + + # Need to figure it out + def get_min_capability(self) -> int: + return 60 + + def get_name(self) -> str: + return "compressed_tensors" + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "CompressedTensorsConfig": + layer_quant_details: Dict[str:Any] = dict() + ignore = config.get("ignore") + fake_quant = config.get("format") == "fakequant" + + for key, quant_config in config["config_groups"].items(): + targets = quant_config.get("targets") + for target in targets: + layer_quant_details[target] = {} + layer_quant_details[target]["weight"] = quant_config.get( + "weights") + layer_quant_details[target]["input"] = quant_config.get( + "input_activations") + + return cls(layer_quant_details=layer_quant_details, + ignore=ignore, + fake_quant=fake_quant) + + @classmethod + def get_config_filenames(cls) -> List[str]: + return ["config.json"] + + def _get_schema(self, weight_quant: Dict, input_quant: Dict): + # TODO: Will static vs dynamic be defined in the config? + # TODO: Expand conditions/break into separate fxs as other + # schemes are supported + + weight_bit = weight_quant.get("num_bits") + input_bit = input_quant.get("num_bits") + + weight_strategy = weight_quant.get("strategy") + input_strategy = input_quant.get("strategy") + + weight_symmetric = weight_quant.get("symmetric") + input_symmetric = input_quant.get("symmetric") + + is_8_bits = weight_bit == input_bit == 8 + is_tensor = weight_strategy == input_strategy == "tensor" + is_symmetric = weight_symmetric and input_symmetric + + if is_8_bits and is_tensor and is_symmetric: + return CompressedTensorsW8A8StaticTensor( + fake_quant=self.fake_quant) + raise NotImplementedError( + "Scheme not supported. 
Only 8-bit static symmtetric " + "per tensor quantization is currently supported") + + def get_scheme(self, layer: torch.nn.Module, + layer_name: str) -> "CompressedTensorsScheme": + + if layer_name is None: + raise ValueError("layer_name must be provided for CompressedTensorsConfig") + + if layer_name in self.ignore: + return CompressedTensorsUnquantized() + + # TODO: update with matching function from `compressed_tensors` + layer_type_name = None + layer_name_class = type(layer).__name__.lower() + for target in self.layer_quant_details: + if target.lower() in layer_name_class: + layer_type_name = target + break + + layer_quant_details = self.layer_quant_details.get(layer_type_name) + if layer_quant_details is None: + raise ValueError( + f"Could not find quantization details for {layer_name}.") + try: + return self._get_schema(weight_quant=layer_quant_details["weight"], + input_quant=layer_quant_details["input"]) + except NotImplementedError as e: + raise e + + +class CompressedTensorsLinearMethod(LinearMethodBase): + + def __init__(self, quantization_config: CompressedTensorsConfig): + self.quantization_config = quantization_config + + def create_weights(self, layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], input_size: int, + output_size: int, params_dtype: torch.dtype, + layer_name: Optional[str] = None, + **extra_weight_attrs): + """ + Use the CompressedTensorsScheme associated with each layer to create the + necessary parameters for the layer. 
+ """ + weight_loader = extra_weight_attrs.get("weight_loader") + + scheme = self.quantization_config.get_scheme(layer=layer, + layer_name=layer_name) + scheme.create_weights( + layer=layer, + input_size_per_partition=input_size_per_partition, + output_partition_sizes=output_partition_sizes, + output_size=output_size, + params_dtype=params_dtype, + weight_loader=weight_loader) + + layer.scheme = scheme + + def apply_weights(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None): + """ + Use the output of create_weights and the CompressedTensorsScheme associated with + the layer to apply the forward pass with the layer input. + """ + + if bias is not None: + raise ValueError("bias is not supported for this linear method") + + scheme = layer.scheme + if scheme is None: + raise ValueError("A scheme must be defined for each layer") + return scheme.apply_weights(layer, x) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/cutlass_gemm.py b/vllm/model_executor/layers/quantization/compressed_tensors/cutlass_gemm.py new file mode 100644 index 000000000000..1b728865641d --- /dev/null +++ b/vllm/model_executor/layers/quantization/compressed_tensors/cutlass_gemm.py @@ -0,0 +1,91 @@ +import cutlass +from cutlass import Tensor as FakeTensor +import cutlass.epilogue + +import torch +from typing import Optional, Tuple, Dict + +from vllm.logger import init_logger + +logger = init_logger("cutlass_gemm") + +def setup_dequant_epilogue(plan : cutlass.op.Gemm, + dq: torch.Tensor, + static_scales: Optional[torch.Tensor], + activation_scales: Optional[torch.Tensor]) \ + -> Tuple[cutlass.op.Gemm, Dict]: + + if all([static_scales is None, activation_scales is None]): + return plan, None + assert static_scales is not None + + def epilog_with_scales_and_act_scales(accum, scales, act_scales): + D = accum * scales * act_scales + return D + + def epilog_with_scales(accum, scales): + D = accum * scales + return D + + epilog_tensors = 
{'scales': static_scales, 'D': dq} + epilogue_trace_tensors = { + "accum": + FakeTensor(element=torch.int32, + shape=dq.shape, + layout_tag=cutlass.LayoutType.RowMajor), + 'scales': + static_scales, + 'D': + dq, + } + epilog_fn = epilog_with_scales + + if activation_scales is not None: + epilog_tensors['act_scales'] = activation_scales + epilogue_trace_tensors['act_scales'] = activation_scales + epilog_fn = epilog_with_scales_and_act_scales + + plan.epilogue_visitor = cutlass.epilogue.trace(epilog_fn, + epilogue_trace_tensors) + return plan, epilog_tensors + + +def cutlass_gemm_dq( + x_q: torch.Tensor, + w_q: torch.Tensor, + dtype: torch.dtype, + static_scales: torch.Tensor, + activation_scales: Optional[torch.Tensor] = None) -> torch.Tensor: + + dq = torch.empty((x_q.shape[0], w_q.shape[0]), dtype=dtype, device="cuda") + + log_str = (f"cutlass_gemm_dq: \n" + f" - x_q {x_q.shape} {x_q.dtype} \n" + f" - w_q {w_q.shape} {w_q.dtype} \n" + f" - o_dq {dq.shape} {dq.dtype} \n") + logger.debug(log_str) + + plan = cutlass.op.Gemm( + element_A=x_q.dtype, + element_B=w_q.dtype, + element_C=dq.dtype, + element_D=dq.dtype, + layout_A=cutlass.LayoutType.RowMajor, + layout_B=cutlass.LayoutType.ColumnMajor, + layout_C=cutlass.LayoutType.RowMajor, + element_accumulator=torch.int32) + + plan, visitor_args = setup_dequant_epilogue(plan, dq, static_scales, + activation_scales) + + plan.run(x_q, + w_q.t(), + dq, + dq, + alpha=1, + beta=0, + visitor_args=visitor_args, + print_module=False) + + dq = dq.view(*x_q.shape[:-1], -1) + return dq diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py new file mode 100644 index 000000000000..5a32069d71e2 --- /dev/null +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py @@ -0,0 +1,3 @@ +from .compressed_tensors_scheme import CompressedTensorsScheme +from .compressed_tensors_unquantized import 
CompressedTensorsUnquantized +from .compressed_tensors_w8a8_statictensor import CompressedTensorsW8A8StaticTensor \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py new file mode 100644 index 000000000000..1873cba9b681 --- /dev/null +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py @@ -0,0 +1,32 @@ +from abc import ABC, abstractmethod +import torch + +__all__ = ["CompressedTensorsScheme"] + + +class CompressedTensorsScheme(ABC): + """ + Abstract class used to describe the weight creation and forward pass of different + quantization schemes supported by CompressedTensors. + """ + + @abstractmethod + def create_weights(self, *args, **kwargs): + """ + Weight creation for the particular scheme. Inputs to this function + + """ + raise NotImplementedError + + @abstractmethod + def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor): + """ + Run the forward pass for the particular scheme. This is where scheme-specific + dequant/quant steps/kernels should be applied. + + :param layer: toch.nn.Module with the registered weights and other parameters + relevant to the particular scheme. 
+ :param x: input to the layer + + """ + raise NotImplementedError diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_unquantized.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_unquantized.py new file mode 100644 index 000000000000..d5b582f6176a --- /dev/null +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_unquantized.py @@ -0,0 +1,36 @@ +from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( + CompressedTensorsScheme) +import torch +from typing import List, Callable +from torch.nn import Parameter +from vllm.model_executor.utils import set_weight_attrs +import torch.nn.functional as F + +__all__ = ["CompressedTensorsUnquantized"] + + +class CompressedTensorsUnquantized(CompressedTensorsScheme): + """ + Implements the scheme for all layers which are ignored in the CompressedTensors + config. The input and loaded weight are used in a linear transformation. 
+ """ + + def create_weights(self, layer: torch.nn.Module, + output_partition_sizes: List[int], + input_size_per_partition: int, + params_dtype: torch.dtype, weight_loader: Callable, + **kwargs): + + weight = Parameter(torch.empty(sum(output_partition_sizes), + input_size_per_partition, + device="cuda", + dtype=params_dtype), + requires_grad=False) + + set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0}) + layer.register_parameter("weight", weight) + set_weight_attrs(weight, {"weight_loader": weight_loader}) + + def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor): + weight = layer.weight + return F.linear(x, weight) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py new file mode 100644 index 000000000000..9698e97f91f4 --- /dev/null +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py @@ -0,0 +1,137 @@ +import torch +from typing import List, Union, Tuple, Callable +from vllm.model_executor.layers.quantization.compressed_tensors.cutlass_gemm import ( + cutlass_gemm_dq) +from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( + CompressedTensorsScheme) +from vllm.model_executor.utils import set_weight_attrs +from torch.nn import Parameter +from vllm._C import ops + +__all__ = ["CompressedTensorsW8A8StaticTensor"] + + +class CompressedTensorsW8A8StaticTensor(CompressedTensorsScheme): + + def __init__(self, fake_quant): + self.fake_quant = fake_quant + + def _quantize(self, + x: torch.Tensor, + scales: torch.Tensor, + logical_widths: List[int], + split_dim: int = 0) -> torch.Tensor: + + x_q = torch.empty_like(x, dtype=torch.int8, device="cuda") + x_q_split = x_q.split(logical_widths, dim=split_dim) + x_split = x.split(logical_widths, dim=split_dim) + + for q, dq, scale in zip(x_q_split, 
x_split, scales): + ops.quant_per_tensor(q, dq, scale.item()) + + return x_q + + def _quantize_single(self, x: torch.Tensor, scale: float): + x_q = torch.empty_like(x, dtype=torch.int8, device="cuda") + ops.quant_per_tensor(x_q, x, scale) + return x_q + + def _shard_id_as_int(self, shard_id: Union[str, int]) -> int: + if isinstance(shard_id, int): + return shard_id + + assert isinstance(shard_id, str) + qkv_idxs = {"q": 0, "k": 1, "v": 2} + assert shard_id in qkv_idxs + return qkv_idxs[shard_id] + + def scales_shard_splitter( + self, param: torch.Tensor, loaded_weight: torch.Tensor, + shard_id: Union[str, int], + logical_widths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + shard_id = self._shard_id_as_int(shard_id) + offset = sum(logical_widths[:shard_id]) + size = logical_widths[shard_id] + # update loaded weight with copies for broadcast. + loaded_weight = loaded_weight.repeat(size) + return param[offset:offset + size], loaded_weight + + def create_weights(self, layer: torch.nn.Module, + output_partition_sizes: List[int], + input_size_per_partition: int, + params_dtype: torch.dtype, weight_loader: Callable, + **kwargs): + + # TODO: remove zero_point parameters once the configs given remove them + is_tensor_partitioned = len(output_partition_sizes) != 1 + dim = sum(output_partition_sizes) if is_tensor_partitioned else 1 + + input_scale = Parameter(torch.empty(1, + device="cuda", + dtype=torch.float32), + requires_grad=False) + input_zero_point = Parameter(torch.empty(1, + device="cuda", + dtype=torch.int8), + requires_grad=False) + + weight_scale = Parameter(torch.empty(dim, + device="cuda", + dtype=torch.float32), + requires_grad=False) + weight_zero_point = Parameter(torch.empty(1, + device="cuda", + dtype=torch.int8), + requires_grad=False) + + if not self.fake_quant: + params_dtype = torch.int8 + weight = Parameter(torch.empty(sum(output_partition_sizes), + input_size_per_partition, + device="cuda", + dtype=params_dtype), + requires_grad=False) + + 
layer.register_parameter("weight", weight) + set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0}) + # Register parameter with the layer; register weight loader with each parameter + set_weight_attrs(weight, {"weight_loader": weight_loader}) + set_weight_attrs(weight, + {"logical_widths": output_partition_sizes}) + + layer.register_parameter("input_scale", input_scale) + set_weight_attrs(input_scale, {"weight_loader": weight_loader}) + layer.register_parameter("input_zero_point", input_zero_point) + set_weight_attrs(input_zero_point, {"weight_loader": weight_loader}) + layer.register_parameter("weight_scale", weight_scale) + set_weight_attrs(weight_scale, {"weight_loader": weight_loader}) + set_weight_attrs( + weight_scale, { + "shard_splitter": self.scales_shard_splitter, + "logical_widths": output_partition_sizes + }) + layer.register_parameter("weight_zero_point", weight_zero_point) + set_weight_attrs(weight_zero_point, {"weight_loader": weight_loader}) + + def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor): + weight = layer.weight + weight_scale = layer.weight_scale + act_scale = layer.input_scale + logical_widths = weight.logical_widths + + # Input quantize + x_q = self._quantize_single(x, act_scale[0].item()) + + # Weight quantize + # TODO : try not to remove device-to-host copy. i.e. 
keep the non-duplicated version + # of scales in the CPU + if self.fake_quant: + w_scales = [ + weight_scale[sum(logical_widths[:i])].item() + for i in range(len(logical_widths)) + ] + w_scales = torch.FloatTensor(w_scales, device=torch.device("cpu")) + w_q = self._quantize(weight, w_scales, logical_widths) + # GEMM and dq + return cutlass_gemm_dq(x_q, w_q, x.dtype, weight_scale, act_scale) + return cutlass_gemm_dq(x_q, weight, x.dtype, weight_scale, act_scale) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index ba9f3149649c..05347ba17707 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -79,6 +79,7 @@ def create_weights( input_size: int, output_size: int, params_dtype: torch.dtype, + layer_name: Optional[str] = None, **extra_weight_attrs, ): output_size_per_partition = sum(output_partition_sizes) diff --git a/vllm/model_executor/layers/quantization/gptq.py b/vllm/model_executor/layers/quantization/gptq.py index ae9f7019f059..b58db2ae7e7f 100644 --- a/vllm/model_executor/layers/quantization/gptq.py +++ b/vllm/model_executor/layers/quantization/gptq.py @@ -98,6 +98,7 @@ def create_weights( input_size: int, output_size: int, params_dtype: torch.dtype, + layer_name: Optional[str] = None, **extra_weight_attrs, ): del output_size # Unused. diff --git a/vllm/model_executor/layers/quantization/marlin.py b/vllm/model_executor/layers/quantization/marlin.py index 94aba620ea08..6644d1e269ff 100644 --- a/vllm/model_executor/layers/quantization/marlin.py +++ b/vllm/model_executor/layers/quantization/marlin.py @@ -100,6 +100,7 @@ def create_weights( input_size: int, output_size: int, params_dtype: torch.dtype, + layer_name: Optional[str] = None, **extra_weight_attrs, ): del output_size # Unused. 
diff --git a/vllm/model_executor/layers/quantization/squeezellm.py b/vllm/model_executor/layers/quantization/squeezellm.py index 207dbcee8afc..4a4627f7e896 100644 --- a/vllm/model_executor/layers/quantization/squeezellm.py +++ b/vllm/model_executor/layers/quantization/squeezellm.py @@ -75,6 +75,7 @@ def create_weights(self, layer: torch.nn.Module, input_size_per_partition: int, output_partition_sizes: List[int], input_size: int, output_size: int, params_dtype: torch.dtype, + layer_name: Optional[str] = None, **extra_weight_attrs): if input_size_per_partition % self.quant_config.pack_factor != 0: raise ValueError( diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index f6d7fc8733fc..b723ce43b89e 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -54,6 +54,7 @@ class LlamaMLP(nn.Module): def __init__( self, + parent_name: str, hidden_size: int, intermediate_size: int, hidden_act: str, @@ -61,13 +62,17 @@ def __init__( ) -> None: super().__init__() self.gate_up_proj = MergedColumnParallelLinear( - hidden_size, [intermediate_size] * 2, + layer_name=f"{parent_name}.gate_up_proj", + input_size=hidden_size, + output_sizes=[intermediate_size] * 2, + bias=False, + quant_config=quant_config) + self.down_proj = RowParallelLinear( + layer_name=f"{parent_name}.down_proj", + input_size=intermediate_size, + output_size=hidden_size, bias=False, quant_config=quant_config) - self.down_proj = RowParallelLinear(intermediate_size, - hidden_size, - bias=False, - quant_config=quant_config) if hidden_act != "silu": raise ValueError(f"Unsupported activation: {hidden_act}. 
" "Only silu is supported for now.") @@ -84,6 +89,7 @@ class LlamaAttention(nn.Module): def __init__( self, + parent_name: str, hidden_size: int, num_heads: int, num_kv_heads: int, @@ -127,16 +133,18 @@ def __init__( self.kv_scale = 1.0 self.qkv_proj = QKVParallelLinear( - hidden_size, - self.head_dim, - self.total_num_heads, - self.total_num_kv_heads, + layer_name=f"{parent_name}.qkv_proj", + hidden_size=hidden_size, + head_size=self.head_dim, + total_num_heads=self.total_num_heads, + total_num_kv_heads=self.total_num_kv_heads, bias=bias, quant_config=quant_config, ) self.o_proj = RowParallelLinear( - self.total_num_heads * self.head_dim, - hidden_size, + layer_name=f"{parent_name}.o_proj", + input_size=self.total_num_heads * self.head_dim, + output_size=hidden_size, bias=bias, quant_config=quant_config, ) @@ -174,6 +182,7 @@ class LlamaDecoderLayer(nn.Module): def __init__( self, + parent_name: str, config: LlamaConfig, quant_config: Optional[QuantizationConfig] = None, ) -> None: @@ -193,6 +202,7 @@ def __init__( attention_bias = getattr(config, "attention_bias", False) or getattr( config, "bias", False) self.self_attn = LlamaAttention( + parent_name=f"{parent_name}.self_attn", hidden_size=self.hidden_size, num_heads=config.num_attention_heads, num_kv_heads=getattr(config, "num_key_value_heads", @@ -205,6 +215,7 @@ def __init__( sliding_window=sliding_window, ) self.mlp = LlamaMLP( + parent_name=f"{parent_name}.mlp", hidden_size=self.hidden_size, intermediate_size=config.intermediate_size, hidden_act=config.hidden_act, @@ -265,8 +276,10 @@ def __init__( org_num_embeddings=config.vocab_size, ) self.layers = nn.ModuleList([ - LlamaDecoderLayer(config, quant_config) - for _ in range(config.num_hidden_layers) + LlamaDecoderLayer(parent_name=f"model.layers.{idx}", + config=config, + quant_config=quant_config) + for idx in range(config.num_hidden_layers) ]) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) diff --git a/vllm/worker/model_runner.py 
b/vllm/worker/model_runner.py index 0704f5fec54d..5b6a29912396 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1101,4 +1101,4 @@ def _prepare_fake_inputs( else: prompt_tokens = [0] * seq_len fake_image_input = None - return SequenceData(prompt_tokens), fake_image_input + return SequenceData(prompt_tokens), fake_image_input \ No newline at end of file From 92b370393c4187be8deb25ae8f0c090b2bda736f Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Tue, 30 Apr 2024 18:52:02 +0000 Subject: [PATCH 2/6] add get_quant method to compressed tensors config --- .../quantization/compressed_tensors/compressed_tensors.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index a61bec6e0323..d9eec5da6c9f 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -2,7 +2,7 @@ import torch -from vllm.model_executor.layers.linear import LinearMethodBase +from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase from vllm.model_executor.layers.quantization.base_config import ( # noqa: E501 QuantizationConfig) @@ -35,6 +35,12 @@ def get_min_capability(self) -> int: def get_name(self) -> str: return "compressed_tensors" + + def get_quant_method( + self, layer: torch.nn.Module) -> Optional["CompressedTensorsLinearMethod"]: + if isinstance(layer, LinearBase): + return CompressedTensorsLinearMethod(self) + return None @classmethod def from_config(cls, config: Dict[str, Any]) -> "CompressedTensorsConfig": From 2a3eb8385ebf8cd7749ffec05481919c778e7be1 Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Tue, 30 Apr 2024 19:06:40 +0000 Subject: [PATCH 3/6] small rebase fixed --- vllm/model_executor/layers/linear.py | 4 ++-- 
vllm/model_executor/layers/quantization/__init__.py | 2 +- .../quantization/compressed_tensors/compressed_tensors.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 5469898972e4..9f7d9d77b047 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -1,4 +1,4 @@ -from abc import ABC, abstractmethod +from abc import abstractmethod from typing import Any, Dict, List, Optional import torch @@ -224,7 +224,7 @@ def __init__( layer_name: Optional[str] = None ): super().__init__(input_size, output_size, skip_bias_add, params_dtype, - quant_config) + quant_config, layer_name) self.gather_output = gather_output diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py index 06fb0c905623..1607efddb657 100644 --- a/vllm/model_executor/layers/quantization/__init__.py +++ b/vllm/model_executor/layers/quantization/__init__.py @@ -4,7 +4,7 @@ from vllm.model_executor.layers.quantization.awq import AWQConfig from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) -from vllm.model_executor.layers.quantization.fp8 import FP8Config +from vllm.model_executor.layers.quantization.fp8 import Fp8Config from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501 CompressedTensorsConfig) from vllm.model_executor.layers.quantization.gptq import GPTQConfig diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index d9eec5da6c9f..5889ce469ae2 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -147,7 +147,7 @@ def create_weights(self, layer: torch.nn.Module, layer.scheme 
= scheme - def apply_weights(self, + def apply(self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] = None): From 3dd1fe8857aa518121927060054c9f3c6f8141f8 Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Tue, 30 Apr 2024 19:25:58 +0000 Subject: [PATCH 4/6] format --- vllm/model_executor/layers/linear.py | 148 +++++++++--------- .../layers/quantization/aqlm.py | 9 +- .../model_executor/layers/quantization/awq.py | 9 +- .../compressed_tensors/compressed_tensors.py | 46 +++--- .../compressed_tensors/cutlass_gemm.py | 17 +- .../compressed_tensors_w8a8_statictensor.py | 3 +- .../layers/quantization/squeezellm.py | 9 +- 7 files changed, 126 insertions(+), 115 deletions(-) diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 9f7d9d77b047..c6dcd48ac765 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -30,13 +30,15 @@ class LinearMethodBase(QuantizeMethodBase): """Base class for different (maybe quantized) linear methods.""" @abstractmethod - def create_weights(self, layer: torch.nn.Module, + def create_weights(self, + layer: torch.nn.Module, input_size_per_partition: int, - output_partition_sizes: List[int], input_size: int, - output_size: int, params_dtype: torch.dtype, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, layer_name: Optional[str] = None, **extra_weight_attrs) -> Dict[str, Any]: - """Create weights for a linear layer. The weights will be set as attributes of the layer. 
@@ -74,10 +76,13 @@ class UnquantizedLinearMethod(LinearMethodBase): def __init__(self, separate_bias_add: bool = False): self.separate_bias_add = separate_bias_add - def create_weights(self, layer: torch.nn.Module, + def create_weights(self, + layer: torch.nn.Module, input_size_per_partition: int, - output_partition_sizes: List[int], input_size: int, - output_size: int, params_dtype: torch.dtype, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, layer_name: Optional[str] = None, **extra_weight_attrs) -> Dict[str, Any]: weight = Parameter(torch.empty(sum(output_partition_sizes), @@ -113,15 +118,13 @@ class LinearBase(torch.nn.Module): layer_name: name of the layer in the state dict. """ - def __init__( - self, - input_size: int, - output_size: int, - skip_bias_add: bool = False, - params_dtype: Optional[torch.dtype] = None, - quant_config: Optional[QuantizationConfig] = None, - layer_name: Optional[str] = None - ): + def __init__(self, + input_size: int, + output_size: int, + skip_bias_add: bool = False, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None, + layer_name: Optional[str] = None): super().__init__() # Keep input parameters @@ -154,24 +157,25 @@ class ReplicatedLinear(LinearBase): quant_config: Quantization configure. """ - def __init__( - self, - input_size: int, - output_size: int, - bias: bool = True, - skip_bias_add: bool = False, - params_dtype: Optional[torch.dtype] = None, - quant_config: Optional[QuantizationConfig] = None, - layer_name: Optional[str] = None - ): + def __init__(self, + input_size: int, + output_size: int, + bias: bool = True, + skip_bias_add: bool = False, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None, + layer_name: Optional[str] = None): super().__init__(input_size, output_size, skip_bias_add, params_dtype, quant_config, layer_name) # All the linear layer supports quant method. 
assert self.quant_method is not None - self.quant_method.create_weights(self, self.input_size, - [self.output_size], self.input_size, - self.output_size, self.params_dtype, layer_name=self.layer_name) + self.quant_method.create_weights(self, + self.input_size, [self.output_size], + self.input_size, + self.output_size, + self.params_dtype, + layer_name=self.layer_name) if bias: self.bias = Parameter( @@ -211,18 +215,16 @@ class ColumnParallelLinear(LinearBase): layer_name: name of the layer in the state dict. """ - def __init__( - self, - input_size: int, - output_size: int, - bias: bool = True, - gather_output: bool = False, - skip_bias_add: bool = False, - params_dtype: Optional[torch.dtype] = None, - quant_config: Optional[QuantizationConfig] = None, - output_sizes: Optional[List[int]] = None, - layer_name: Optional[str] = None - ): + def __init__(self, + input_size: int, + output_size: int, + bias: bool = True, + gather_output: bool = False, + skip_bias_add: bool = False, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None, + output_sizes: Optional[List[int]] = None, + layer_name: Optional[str] = None): super().__init__(input_size, output_size, skip_bias_add, params_dtype, quant_config, layer_name) @@ -310,17 +312,15 @@ class MergedColumnParallelLinear(ColumnParallelLinear): quant_config: Quantization configure. 
""" - def __init__( - self, - input_size: int, - output_sizes: List[int], - bias: bool = True, - gather_output: bool = False, - skip_bias_add: bool = False, - params_dtype: Optional[torch.dtype] = None, - quant_config: Optional[QuantizationConfig] = None, - layer_name: Optional[str] = None - ): + def __init__(self, + input_size: int, + output_sizes: List[int], + bias: bool = True, + gather_output: bool = False, + skip_bias_add: bool = False, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None, + layer_name: Optional[str] = None): self.output_sizes = output_sizes tp_size = get_tensor_model_parallel_world_size() assert all(output_size % tp_size == 0 for output_size in output_sizes) @@ -460,18 +460,16 @@ class QKVParallelLinear(ColumnParallelLinear): layer_name: name of the layer in the state dict. """ - def __init__( - self, - hidden_size: int, - head_size: int, - total_num_heads: int, - total_num_kv_heads: Optional[int] = None, - bias: bool = True, - skip_bias_add: bool = False, - params_dtype: Optional[torch.dtype] = None, - quant_config: Optional[QuantizationConfig] = None, - layer_name: Optional[str] = None - ): + def __init__(self, + hidden_size: int, + head_size: int, + total_num_heads: int, + total_num_kv_heads: Optional[int] = None, + bias: bool = True, + skip_bias_add: bool = False, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None, + layer_name: Optional[str] = None): self.hidden_size = hidden_size self.head_size = head_size self.total_num_heads = total_num_heads @@ -652,18 +650,16 @@ class RowParallelLinear(LinearBase): layer_name: name of the layer in the state dict. 
""" - def __init__( - self, - input_size: int, - output_size: int, - bias: bool = True, - input_is_parallel: bool = True, - skip_bias_add: bool = False, - params_dtype: Optional[torch.dtype] = None, - reduce_results: bool = True, - quant_config: Optional[QuantizationConfig] = None, - layer_name: Optional[str] = None - ): + def __init__(self, + input_size: int, + output_size: int, + bias: bool = True, + input_is_parallel: bool = True, + skip_bias_add: bool = False, + params_dtype: Optional[torch.dtype] = None, + reduce_results: bool = True, + quant_config: Optional[QuantizationConfig] = None, + layer_name: Optional[str] = None): super().__init__(input_size, output_size, skip_bias_add, params_dtype, quant_config, layer_name) diff --git a/vllm/model_executor/layers/quantization/aqlm.py b/vllm/model_executor/layers/quantization/aqlm.py index 6edb3c3e9c63..1215f818de90 100644 --- a/vllm/model_executor/layers/quantization/aqlm.py +++ b/vllm/model_executor/layers/quantization/aqlm.py @@ -227,10 +227,13 @@ class AQLMLinearMethod(LinearMethodBase): def __init__(self, quant_config: AQLMConfig): self.quant_config = quant_config - def create_weights(self, layer: torch.nn.Module, + def create_weights(self, + layer: torch.nn.Module, input_size_per_partition: int, - output_partition_sizes: List[int], input_size: int, - output_size: int, params_dtype: torch.dtype, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, layer_name: Optional[str] = None, **extra_weight_attrs): del output_size # Unused. 
diff --git a/vllm/model_executor/layers/quantization/awq.py b/vllm/model_executor/layers/quantization/awq.py index 00b4a4714be1..58e3fd0d1d84 100644 --- a/vllm/model_executor/layers/quantization/awq.py +++ b/vllm/model_executor/layers/quantization/awq.py @@ -82,10 +82,13 @@ class AWQLinearMethod(LinearMethodBase): def __init__(self, quant_config: AWQConfig): self.quant_config = quant_config - def create_weights(self, layer: torch.nn.Module, + def create_weights(self, + layer: torch.nn.Module, input_size_per_partition: int, - output_partition_sizes: List[int], input_size: int, - output_size: int, params_dtype: torch.dtype, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, layer_name: Optional[str] = None, **extra_weight_attrs): if input_size_per_partition % self.quant_config.group_size != 0: diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index 5889ce469ae2..5b6001d79732 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -35,18 +35,19 @@ def get_min_capability(self) -> int: def get_name(self) -> str: return "compressed_tensors" - + def get_quant_method( - self, layer: torch.nn.Module) -> Optional["CompressedTensorsLinearMethod"]: + self, layer: torch.nn.Module + ) -> Optional["CompressedTensorsLinearMethod"]: if isinstance(layer, LinearBase): return CompressedTensorsLinearMethod(self) return None @classmethod def from_config(cls, config: Dict[str, Any]) -> "CompressedTensorsConfig": - layer_quant_details: Dict[str:Any] = dict() - ignore = config.get("ignore") - fake_quant = config.get("format") == "fakequant" + layer_quant_details: Dict[str, Any] = dict() + ignore: List[str] = config.get("ignore", None) + fake_quant: bool = config.get("format") == 
"fakequant" for key, quant_config in config["config_groups"].items(): targets = quant_config.get("targets") @@ -66,9 +67,7 @@ def get_config_filenames(cls) -> List[str]: return ["config.json"] def _get_schema(self, weight_quant: Dict, input_quant: Dict): - # TODO: Will static vs dynamic be defined in the config? - # TODO: Expand conditions/break into separate fxs as other - # schemes are supported + # TODO: Refactor as additional cases are supported weight_bit = weight_quant.get("num_bits") input_bit = input_quant.get("num_bits") @@ -90,11 +89,14 @@ def _get_schema(self, weight_quant: Dict, input_quant: Dict): "Scheme not supported. Only 8-bit static symmtetric " "per tensor quantization is currently supported") - def get_scheme(self, layer: torch.nn.Module, - layer_name: str) -> "CompressedTensorsScheme": + def get_scheme( + self, + layer: torch.nn.Module, + layer_name: Optional[str] = None) -> "CompressedTensorsScheme": if layer_name is None: - raise ValueError("layer_name must be provided for CompressedTensorsConfig") + raise ValueError( + "layer_name must be provided for CompressedTensorsConfig") if layer_name in self.ignore: return CompressedTensorsUnquantized() @@ -106,8 +108,11 @@ def get_scheme(self, layer: torch.nn.Module, if target.lower() in layer_name_class: layer_type_name = target break + if layer_type_name is None: + raise ValueError(f"Could not matching target for layer {layer}") - layer_quant_details = self.layer_quant_details.get(layer_type_name) + layer_quant_details: Dict[str, Any] = self.layer_quant_details.get( + layer_type_name, None) if layer_quant_details is None: raise ValueError( f"Could not find quantization details for {layer_name}.") @@ -123,10 +128,13 @@ class CompressedTensorsLinearMethod(LinearMethodBase): def __init__(self, quantization_config: CompressedTensorsConfig): self.quantization_config = quantization_config - def create_weights(self, layer: torch.nn.Module, + def create_weights(self, + layer: torch.nn.Module, 
input_size_per_partition: int, - output_partition_sizes: List[int], input_size: int, - output_size: int, params_dtype: torch.dtype, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, layer_name: Optional[str] = None, **extra_weight_attrs): """ @@ -146,11 +154,11 @@ def create_weights(self, layer: torch.nn.Module, weight_loader=weight_loader) layer.scheme = scheme - + def apply(self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None): + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None): """ Use the output of create_weights and the CompressedTensorsScheme associated with the layer to apply the forward pass with the layer input. diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/cutlass_gemm.py b/vllm/model_executor/layers/quantization/compressed_tensors/cutlass_gemm.py index 1b728865641d..1766aed1d692 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/cutlass_gemm.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/cutlass_gemm.py @@ -65,15 +65,14 @@ def cutlass_gemm_dq( f" - o_dq {dq.shape} {dq.dtype} \n") logger.debug(log_str) - plan = cutlass.op.Gemm( - element_A=x_q.dtype, - element_B=w_q.dtype, - element_C=dq.dtype, - element_D=dq.dtype, - layout_A=cutlass.LayoutType.RowMajor, - layout_B=cutlass.LayoutType.ColumnMajor, - layout_C=cutlass.LayoutType.RowMajor, - element_accumulator=torch.int32) + plan = cutlass.op.Gemm(element_A=x_q.dtype, + element_B=w_q.dtype, + element_C=dq.dtype, + element_D=dq.dtype, + layout_A=cutlass.LayoutType.RowMajor, + layout_B=cutlass.LayoutType.ColumnMajor, + layout_C=cutlass.LayoutType.RowMajor, + element_accumulator=torch.int32) plan, visitor_args = setup_dequant_epilogue(plan, dq, static_scales, activation_scales) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py 
b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py index 9698e97f91f4..38b810f1c9ab 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py @@ -96,8 +96,7 @@ def create_weights(self, layer: torch.nn.Module, set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0}) # Register parameter with the layer; register weight loader with each parameter set_weight_attrs(weight, {"weight_loader": weight_loader}) - set_weight_attrs(weight, - {"logical_widths": output_partition_sizes}) + set_weight_attrs(weight, {"logical_widths": output_partition_sizes}) layer.register_parameter("input_scale", input_scale) set_weight_attrs(input_scale, {"weight_loader": weight_loader}) diff --git a/vllm/model_executor/layers/quantization/squeezellm.py b/vllm/model_executor/layers/quantization/squeezellm.py index 4a4627f7e896..6f408b491f1a 100644 --- a/vllm/model_executor/layers/quantization/squeezellm.py +++ b/vllm/model_executor/layers/quantization/squeezellm.py @@ -71,10 +71,13 @@ class SqueezeLLMLinearMethod(QuantizeMethodBase): def __init__(self, quant_config: SqueezeLLMConfig): self.quant_config = quant_config - def create_weights(self, layer: torch.nn.Module, + def create_weights(self, + layer: torch.nn.Module, input_size_per_partition: int, - output_partition_sizes: List[int], input_size: int, - output_size: int, params_dtype: torch.dtype, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, layer_name: Optional[str] = None, **extra_weight_attrs): if input_size_per_partition % self.quant_config.pack_factor != 0: From f2f8c5261acb51d378b007482955ef54debaf80f Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Tue, 30 Apr 2024 20:04:21 +0000 Subject: [PATCH 5/6] fix mypy complaints --- 
.../compressed_tensors/cutlass_gemm.py | 20 +++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/cutlass_gemm.py b/vllm/model_executor/layers/quantization/compressed_tensors/cutlass_gemm.py index 1766aed1d692..b3eccbdf6fec 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/cutlass_gemm.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/cutlass_gemm.py @@ -3,27 +3,31 @@ import cutlass.epilogue import torch -from typing import Optional, Tuple, Dict +from typing import Optional, Tuple, Dict, Union, Any from vllm.logger import init_logger logger = init_logger("cutlass_gemm") +# type alias +TF = Union[torch.Tensor, float] + def setup_dequant_epilogue(plan : cutlass.op.Gemm, - dq: torch.Tensor, - static_scales: Optional[torch.Tensor], - activation_scales: Optional[torch.Tensor]) \ - -> Tuple[cutlass.op.Gemm, Dict]: + dq : torch.Tensor, + static_scales: Optional[TF], + activation_scales: Optional[TF]) \ + -> Tuple[cutlass.op.Gemm, Optional[Dict]]: if all([static_scales is None, activation_scales is None]): return plan, None assert static_scales is not None - def epilog_with_scales_and_act_scales(accum, scales, act_scales): + def epilog_with_scales_and_act_scales(accum: torch.Tensor, scales: TF, + act_scales: TF) -> torch.Tensor: D = accum * scales * act_scales return D - def epilog_with_scales(accum, scales): + def epilog_with_scales(accum: torch.Tensor, scales: TF) -> torch.Tensor: D = accum * scales return D @@ -38,7 +42,7 @@ def epilog_with_scales(accum, scales): 'D': dq, } - epilog_fn = epilog_with_scales + epilog_fn: Any = epilog_with_scales if activation_scales is not None: epilog_tensors['act_scales'] = activation_scales From d9d49b5224dccb16eb28628ed9fb5f95b07437cc Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Tue, 30 Apr 2024 21:25:31 +0000 Subject: [PATCH 6/6] format fixes --- vllm/model_executor/layers/linear.py | 10 +++++----- 
.../layers/quantization/__init__.py | 2 +- .../compressed_tensors/compressed_tensors.py | 15 +++++++-------- .../compressed_tensors/cutlass_gemm.py | 6 +++--- .../compressed_tensors/schemes/__init__.py | 8 +++++--- .../schemes/compressed_tensors_scheme.py | 13 +++++++------ .../schemes/compressed_tensors_unquantized.py | 15 +++++++++------ .../compressed_tensors_w8a8_statictensor.py | 16 +++++++++------- .../layers/quantization/gptq_marlin.py | 1 + 9 files changed, 47 insertions(+), 39 deletions(-) diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index c6dcd48ac765..d155b5704d5a 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -1,5 +1,5 @@ from abc import abstractmethod -from typing import Any, Dict, List, Optional +from typing import List, Optional import torch import torch.nn.functional as F @@ -38,7 +38,7 @@ def create_weights(self, output_size: int, params_dtype: torch.dtype, layer_name: Optional[str] = None, - **extra_weight_attrs) -> Dict[str, Any]: + **extra_weight_attrs): """Create weights for a linear layer. The weights will be set as attributes of the layer. @@ -84,7 +84,7 @@ def create_weights(self, output_size: int, params_dtype: torch.dtype, layer_name: Optional[str] = None, - **extra_weight_attrs) -> Dict[str, Any]: + **extra_weight_attrs): weight = Parameter(torch.empty(sum(output_partition_sizes), input_size_per_partition, dtype=params_dtype), @@ -413,7 +413,7 @@ def weight_loader(self, param_data = param_data.narrow(0, shard_offset, shard_size) # If a param_shard_splitter is defined by the LinearMethod, use it. 
elif param_shard_splitter is not None: - logical_widths = getattr(param, "logical_widths") + logical_widths = getattr(param, "logical_widths", None) param_data, loaded_weight = param_shard_splitter( param_data, loaded_weight, loaded_shard_id, logical_widths) @@ -601,7 +601,7 @@ def weight_loader(self, shard_size) # If a param_shard_splitter is defined by the LinearMethod, use it. elif param_shard_splitter is not None: - logical_widths = getattr(param, "logical_widths") + logical_widths = getattr(param, "logical_widths", None) param_data, loaded_weight = param_shard_splitter( param_data, loaded_weight, loaded_shard_id, logical_widths) diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py index 14052dc72583..73fd41d7656e 100644 --- a/vllm/model_executor/layers/quantization/__init__.py +++ b/vllm/model_executor/layers/quantization/__init__.py @@ -4,9 +4,9 @@ from vllm.model_executor.layers.quantization.awq import AWQConfig from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) -from vllm.model_executor.layers.quantization.fp8 import Fp8Config from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501 CompressedTensorsConfig) +from vllm.model_executor.layers.quantization.fp8 import Fp8Config from vllm.model_executor.layers.quantization.gptq import GPTQConfig from vllm.model_executor.layers.quantization.gptq_marlin import ( GPTQMarlinConfig) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index 5b6001d79732..599cce689d65 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -5,11 +5,9 @@ from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase from 
vllm.model_executor.layers.quantization.base_config import ( # noqa: E501 QuantizationConfig) - from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( - CompressedTensorsW8A8StaticTensor, CompressedTensorsUnquantized, - CompressedTensorsScheme) -from vllm.model_executor.utils import set_weight_attrs + CompressedTensorsScheme, CompressedTensorsUnquantized, + CompressedTensorsW8A8StaticTensor) class CompressedTensorsConfig(QuantizationConfig): @@ -138,8 +136,8 @@ def create_weights(self, layer_name: Optional[str] = None, **extra_weight_attrs): """ - Use the CompressedTensorsScheme associated with each layer to create the - necessary parameters for the layer. + Use the CompressedTensorsScheme associated with each layer to create + the necessary parameters for the layer. """ weight_loader = extra_weight_attrs.get("weight_loader") @@ -160,8 +158,9 @@ def apply(self, x: torch.Tensor, bias: Optional[torch.Tensor] = None): """ - Use the output of create_weights and the CompressedTensorsScheme associated with - the layer to apply the forward pass with the layer input. + Use the output of create_weights and the CompressedTensorsScheme + associated with the layer to apply the forward pass with the + layer input. 
""" if bias is not None: diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/cutlass_gemm.py b/vllm/model_executor/layers/quantization/compressed_tensors/cutlass_gemm.py index b3eccbdf6fec..72720a934227 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/cutlass_gemm.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/cutlass_gemm.py @@ -1,9 +1,9 @@ +from typing import Any, Dict, Optional, Tuple, Union + import cutlass -from cutlass import Tensor as FakeTensor import cutlass.epilogue - import torch -from typing import Optional, Tuple, Dict, Union, Any +from cutlass import Tensor as FakeTensor from vllm.logger import init_logger diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py index 5a32069d71e2..831905b63e2c 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py @@ -1,3 +1,5 @@ -from .compressed_tensors_scheme import CompressedTensorsScheme -from .compressed_tensors_unquantized import CompressedTensorsUnquantized -from .compressed_tensors_w8a8_statictensor import CompressedTensorsW8A8StaticTensor \ No newline at end of file +from .compressed_tensors_scheme import CompressedTensorsScheme # noqa: F401 +from .compressed_tensors_unquantized import ( # noqa: F401 + CompressedTensorsUnquantized) +from .compressed_tensors_w8a8_statictensor import ( # noqa: F401, E501 + CompressedTensorsW8A8StaticTensor) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py index 1873cba9b681..3a5904208656 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py +++ 
b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py @@ -1,4 +1,5 @@ from abc import ABC, abstractmethod + import torch __all__ = ["CompressedTensorsScheme"] @@ -6,8 +7,8 @@ class CompressedTensorsScheme(ABC): """ - Abstract class used to describe the weight creation and forward pass of different - quantization schemes supported by CompressedTensors. + Abstract class used to describe the weight creation and forward pass + of different quantization schemes supported by CompressedTensors. """ @abstractmethod @@ -21,11 +22,11 @@ def create_weights(self, *args, **kwargs): @abstractmethod def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor): """ - Run the forward pass for the particular scheme. This is where scheme-specific - dequant/quant steps/kernels should be applied. + Run the forward pass for the particular scheme. This is where + scheme-specific dequant/quant steps/kernels should be applied. - :param layer: toch.nn.Module with the registered weights and other parameters - relevant to the particular scheme. + :param layer: toch.nn.Module with the registered weights and + other parameters relevant to the particular scheme. 
:param x: input to the layer """ diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_unquantized.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_unquantized.py index d5b582f6176a..0cfac13d1ca2 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_unquantized.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_unquantized.py @@ -1,18 +1,21 @@ -from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( - CompressedTensorsScheme) +from typing import Callable, List + import torch -from typing import List, Callable +import torch.nn.functional as F from torch.nn import Parameter + +from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( + CompressedTensorsScheme) from vllm.model_executor.utils import set_weight_attrs -import torch.nn.functional as F __all__ = ["CompressedTensorsUnquantized"] class CompressedTensorsUnquantized(CompressedTensorsScheme): """ - Implements the scheme for all layers which are ignored in the CompressedTensors - config. The input and loaded weight are used in a linear transformation. + Implements the scheme for all layers which are ignored + in the CompressedTensors config. The input and loaded weight are used + in a linear transformation. 
""" def create_weights(self, layer: torch.nn.Module, diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py index 38b810f1c9ab..03252882d2ed 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py @@ -1,12 +1,14 @@ +from typing import Callable, List, Tuple, Union + import torch -from typing import List, Union, Tuple, Callable -from vllm.model_executor.layers.quantization.compressed_tensors.cutlass_gemm import ( +from torch.nn import Parameter + +from vllm._C import ops +from vllm.model_executor.layers.quantization.compressed_tensors.cutlass_gemm import ( # noqa: E501 cutlass_gemm_dq) from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( CompressedTensorsScheme) from vllm.model_executor.utils import set_weight_attrs -from torch.nn import Parameter -from vllm._C import ops __all__ = ["CompressedTensorsW8A8StaticTensor"] @@ -94,7 +96,7 @@ def create_weights(self, layer: torch.nn.Module, layer.register_parameter("weight", weight) set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0}) - # Register parameter with the layer; register weight loader with each parameter + set_weight_attrs(weight, {"weight_loader": weight_loader}) set_weight_attrs(weight, {"logical_widths": output_partition_sizes}) @@ -122,8 +124,8 @@ def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor): x_q = self._quantize_single(x, act_scale[0].item()) # Weight quantize - # TODO : try not to remove device-to-host copy. i.e. keep the non-duplicated version - # of scales in the CPU + # TODO : try not to remove device-to-host copy. + # i.e. 
keep the non-duplicated version of scales in the CPU if self.fake_quant: w_scales = [ weight_scale[sum(logical_widths[:i])].item() diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index efbffa0878c4..07e57302d9a8 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -206,6 +206,7 @@ def create_weights( input_size: int, output_size: int, params_dtype: torch.dtype, + layer_name: Optional[str] = None, **extra_weight_attrs, ) -> None: del output_size