diff --git a/examples/offline_inference.py b/examples/offline_inference.py index 9b758fa2479f..6b548d5e8921 100644 --- a/examples/offline_inference.py +++ b/examples/offline_inference.py @@ -11,7 +11,7 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95) # Create an LLM. -llm = LLM(model="facebook/opt-125m") +llm = LLM(model="meta-llama/Llama-2-7b-chat-hf") # Generate texts from the prompts. The output is a list of RequestOutput objects # that contain the prompt, generated text, and other information. outputs = llm.generate(prompts, sampling_params) diff --git a/examples/offline_quantized_inference.py b/examples/offline_quantized_inference.py index 124a99468704..8b3dbea72ae6 100644 --- a/examples/offline_quantized_inference.py +++ b/examples/offline_quantized_inference.py @@ -17,7 +17,7 @@ # Create an LLM. llm = LLM( - model=model_path, + model="nm-testing/Nous-Hermes-Llama2-13b-smoothquant", gpu_memory_utilization=0.9, max_model_len=2048, quantization="smoothquant", diff --git a/examples/simple_test.py b/examples/simple_test.py new file mode 100644 index 000000000000..dcf8b8c7ed1e --- /dev/null +++ b/examples/simple_test.py @@ -0,0 +1,35 @@ +import argparse +from vllm import LLM, SamplingParams + +MODELS = { + "tinyllama-fp16": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "tinyllama-marlin": "neuralmagic/TinyLlama-1.1B-Chat-v1.0-marlin", + "tinyllama-gptq": "TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", + "tinyllama-awq": "TheBloke/TinyLlama-1.1B-Chat-v1.0-AWQ", +} + +parser = argparse.ArgumentParser() +parser.add_argument("--model", type=str) +parser.add_argument("--tensor-parallel-size", type=int, default=1) +args = parser.parse_args() + +if args.model not in MODELS: + print(f"Got model id of {args.model}; Must be in {list(MODELS.keys())}") + raise ValueError +else: + model_id = MODELS[args.model] + print(f"Using model_id = {model_id}") + +messages=[{ + "role": "system", + "content": "You are a helpful assistant." +}, { + "role": "user", + "content": "What is deep learning?" 
+}] + +model = LLM(model_id, enforce_eager=True, max_model_len=2048, tensor_parallel_size=args.tensor_parallel_size, dtype="float16") +prompt = model.llm_engine.tokenizer.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) +out = model.generate(prompt, SamplingParams(max_tokens=50)) +print(f"\n-----prompt\n{prompt}") +print(f"\n-----generation\n{out[0].outputs[0].text}") diff --git a/vllm/config.py b/vllm/config.py index 3149aaf68914..cd48fe4f1b9d 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -173,8 +173,8 @@ def _verify_tokenizer_mode(self) -> None: self.tokenizer_mode = tokenizer_mode def _verify_quantization(self) -> None: - supported_quantization = ["awq", "gptq", "squeezellm", "smoothquant"] - rocm_not_supported_quantization = ["awq", "marlin"] + supported_quantization = ["awq", "gptq", "marlin", "squeezellm", "smoothquant"] + rocm_not_supported_quantization = ["awq", "marlin", "smoothquant"] if self.quantization is not None: self.quantization = self.quantization.lower() diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index f3d4d1789db2..2598156bbed3 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -29,8 +29,11 @@ class LinearMethodBase(ABC): """Base class for different (maybe quantized) linear methods.""" @abstractmethod - def create_weights(self, input_size_per_partition: int, - output_size_per_partition: int, input_size: int, + def create_weights(self, + layer_name: str, + input_size_per_partition: int, + output_sizes_per_partition: List[int], + input_size: int, output_size: int, params_dtype: torch.dtype) -> Dict[str, Any]: """Create weights for a linear layer.""" @@ -43,6 +46,12 @@ def apply_weights(self, bias: Optional[torch.Tensor] = None) -> torch.Tensor: """Apply the weights to the input tensor.""" raise NotImplementedError + + def maybe_update_loaded_weight_name(self, name: str) -> str: + """Update the name of a loaded weight to enable generic handling of + cases where serialized state_dict does not match vllm model definition. + """ + return name class UnquantizedLinearMethod(LinearMethodBase): @@ -56,17 +65,20 @@ class UnquantizedLinearMethod(LinearMethodBase): def __init__(self, separate_bias_add: bool = False): self.separate_bias_add = separate_bias_add - def create_weights(self, input_size_per_partition: int, - output_size_per_partition: int, input_size: int, - output_size: int, + def create_weights(self, + layer_name: str, + input_size_per_partition: int, + output_sizes_per_partition: List[int], + input_size: int, output_size: int, params_dtype: torch.dtype) -> Dict[str, Any]: - weight = Parameter(torch.empty(output_size_per_partition, + weight = Parameter(torch.empty(sum(output_sizes_per_partition), input_size_per_partition, dtype=params_dtype), requires_grad=False) set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0}) return {"weight": weight} + def apply_weights(self, weights: Dict[str, torch.Tensor], x: torch.Tensor, @@ -83,6 +95,7 @@ class ReplicatedLinear(torch.nn.Module): """Replicated linear layer. Args: + layer_name: name of the layer in the state dict. input_size: input dimension of the linear layer. output_size: output dimension of the linear layer. bias: If true, add bias. 
@@ -93,6 +106,7 @@ class ReplicatedLinear(torch.nn.Module):

     def __init__(
         self,
+        layer_name: str,
         input_size: int,
         output_size: int,
         bias: bool = True,
@@ -103,6 +117,7 @@ def __init__(
         super().__init__()

         # Keep input parameters
+        self.layer_name = layer_name
         self.input_size = input_size
         self.output_size = output_size
         self.skip_bias_add = skip_bias_add
@@ -113,8 +128,8 @@ def __init__(
             linear_method = UnquantizedLinearMethod()
         self.linear_method = linear_method
         self.linear_weights = self.linear_method.create_weights(
-            self.input_size, self.output_size, self.input_size,
-            self.output_size, self.params_dtype)
+            self.layer_name, self.input_size, [self.output_size],
+            self.input_size, self.output_size, self.params_dtype)
         for name, weight in self.linear_weights.items():
             if isinstance(weight, torch.Tensor):
                 self.register_parameter(name, weight)
@@ -139,6 +154,7 @@ class ColumnParallelLinear(torch.nn.Module):
     its second dimension as A = [A_1, ..., A_p].

     Args:
+        layer_name: name of the layer in the state dict.
         input_size: first dimension of matrix A.
         output_size: second dimension of matrix A.
         bias: If true, add bias.
@@ -150,10 +166,14 @@ class ColumnParallelLinear(torch.nn.Module):
                        skip adding bias but instead return it.
         params_dtype: Data type for the parameters.
         linear_method: (Maybe quantized) linear method.
+        logical_widths: Optional list of widths for logical weight matrices.
+                        E.g. for QKVParallelLinear, this parameter defines
+                        the width of each of the stacked q, k and v projections.
     """

     def __init__(
         self,
+        layer_name: str,
         input_size: int,
         output_size: int,
         bias: bool = True,
@@ -165,12 +185,20 @@ def __init__(
         super().__init__()

         # Keep input parameters
+        self.layer_name = layer_name
         self.input_size = input_size
         self.output_size = output_size
         self.gather_output = gather_output
         # Divide the weight matrix along the last dimension.
         tp_size = get_tensor_model_parallel_world_size()
-        self.output_size_per_partition = divide(output_size, tp_size)
+        self.output_size_per_partition = divide(self.output_size, tp_size)
+        self.output_sizes_per_partition = [self.output_size_per_partition]
+        # If QKV or MergedColumn, use output size of each partition.
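+        # MergedColumnParallelLinear and QKVParallelLinear set self.output_sizes
+        # before calling super().__init__(); plain ColumnParallelLinear does not.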
+        if getattr(self, "output_sizes", None) is not None:
+            self.output_sizes_per_partition = [
+                divide(output_size, tp_size) for output_size in self.output_sizes
+            ]
+
         self.skip_bias_add = skip_bias_add
         if params_dtype is None:
             params_dtype = torch.get_default_dtype()
@@ -179,8 +207,13 @@ def __init__(
             linear_method = UnquantizedLinearMethod()
         self.linear_method = linear_method
         self.linear_weights = self.linear_method.create_weights(
-            self.input_size, self.output_size_per_partition, self.input_size,
-            self.output_size, self.params_dtype)
+            layer_name=self.layer_name,
+            input_size_per_partition=self.input_size,
+            output_sizes_per_partition=self.output_sizes_per_partition,
+            input_size=self.input_size,
+            output_size=self.output_size,
+            params_dtype=self.params_dtype,
+        )
         for name, weight in self.linear_weights.items():
             if isinstance(weight, torch.Tensor):
                 self.register_parameter(name, weight)
@@ -246,6 +279,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):

     def __init__(
         self,
+        layer_name: str,
         input_size: int,
         output_sizes: List[int],
         bias: bool = True,
@@ -257,8 +291,15 @@ def __init__(
         self.output_sizes = output_sizes
         tp_size = get_tensor_model_parallel_world_size()
         assert all(output_size % tp_size == 0 for output_size in output_sizes)
-        super().__init__(input_size, sum(output_sizes), bias, gather_output,
-                         skip_bias_add, params_dtype, linear_method)
+        super().__init__(
+            layer_name=layer_name,
+            input_size=input_size,
+            output_size=sum(output_sizes),
+            bias=bias,
+            gather_output=gather_output,
+            skip_bias_add=skip_bias_add,
+            params_dtype=params_dtype,
+            linear_method=linear_method)

     def weight_loader(self,
                       param: Parameter,
@@ -266,6 +307,18 @@ def weight_loader,
                       loaded_shard_id: Optional[int] = None):
         param_data = param.data
         output_dim = getattr(param, "output_dim", None)
+        param_shard_splitter = getattr(param, "shard_splitter", None)
+        if output_dim is not None and param_shard_splitter is not None:
+            raise NotImplementedError(
+                "We do not currently support output_dim != None and "
+                "shard_splitter != None for a parameter. Please open an issue."
+            )
+        if loaded_shard_id is None and param_shard_splitter is not None:
+            raise NotImplementedError(
+                "We do not currently support loaded_shard_id == None and "
+                "shard_splitter != None for a parameter. Please open an issue."
+            )
+
         if loaded_shard_id is None:
             # Loaded weight is already packed.
             if output_dim is None:
@@ -318,6 +371,10 @@ def weight_loader,
             start_idx = tp_rank * shard_size
             loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                  shard_size)
+        # If a param_shard_splitter is defined by the LinearMethod, use it.
+        elif param_shard_splitter is not None:
+            param_data, loaded_weight = param_shard_splitter(
+                param_data, loaded_weight, loaded_shard_id)
         else:
             ignore_warning = getattr(param, "ignore_warning", False)
             if not ignore_warning:
@@ -325,6 +382,7 @@ def weight_loader,
                     "Loading a weight without `output_dim` attribute in "
                     "MergedColumnParallelLinear, assume the weight is "
                     "the same for all partitions.")
+
         assert param_data.shape == loaded_weight.shape
         param_data.copy_(loaded_weight)

@@ -340,6 +398,7 @@ class QKVParallelLinear(ColumnParallelLinear):
     be replicated while the query heads are partitioned.

     Args:
+        layer_name: name of the layer in the state dict.
         hidden_size: input hidden state size of the transformer.
         head_size: size of each attention head.
         total_num_heads: total number of attention query heads.
@@ -355,6 +414,7 @@ class QKVParallelLinear(ColumnParallelLinear): def __init__( self, + layer_name: str, hidden_size: int, head_size: int, total_num_heads: int, @@ -383,8 +443,21 @@ def __init__( input_size = self.hidden_size output_size = (self.num_heads + 2 * self.num_kv_heads) * tp_size * self.head_size - super().__init__(input_size, output_size, bias, False, skip_bias_add, - params_dtype, linear_method) + self.output_sizes = [ + self.num_heads * self.head_size * tp_size, # q_proj + self.num_kv_heads * self.head_size * tp_size, # k_proj + self.num_kv_heads * self.head_size * tp_size, # v_proj + ] + + super().__init__( + layer_name=layer_name, + input_size=input_size, + output_size=output_size, + bias=bias, + gather_output=False, + skip_bias_add=skip_bias_add, + params_dtype=params_dtype, + linear_method=linear_method) def weight_loader(self, param: Parameter, @@ -392,6 +465,18 @@ def weight_loader(self, loaded_shard_id: Optional[str] = None): param_data = param.data output_dim = getattr(param, "output_dim", None) + param_shard_splitter = getattr(param, "shard_splitter", None) + + if output_dim is not None and param_shard_splitter is not None: + raise NotImplementedError( + "We do not currently support output_dim != None and " + "shard_splitter != None for a parameter. Please open an issue." + ) + if loaded_shard_id is None and param_shard_splitter is not None: + raise NotImplementedError( + "We do not currently support loaded_shard_id == None and " + "shard_splitter != None for a parameter. Please open an issue." + ) if loaded_shard_id is None: # Loaded weight is already packed. @@ -427,6 +512,8 @@ def weight_loader(self, tp_rank = get_tensor_model_parallel_rank() assert loaded_shard_id in ["q", "k", "v"] + + # If output dim is defined, use the default loading process. if output_dim is not None: if loaded_shard_id == "q": shard_offset = 0 @@ -450,15 +537,19 @@ def weight_loader(self, shard_size, shard_offset = adjust_marlin_shard( param, shard_size, shard_offset) - param_data = param_data.narrow(output_dim, shard_offset, - shard_size) + param_data = param_data.narrow( + output_dim, shard_offset, shard_size) if loaded_shard_id == "q": shard_id = tp_rank else: shard_id = tp_rank // self.num_kv_head_replicas start_idx = shard_id * shard_size loaded_weight = loaded_weight.narrow(output_dim, start_idx, - shard_size) + shard_size) + # If a param_shard_splitter is defined by the LinearMethod, use it. + elif param_shard_splitter is not None: + param_data, loaded_weight = param_shard_splitter( + param_data, loaded_weight, loaded_shard_id) else: ignore_warning = getattr(param, "ignore_warning", False) if not ignore_warning: @@ -466,7 +557,11 @@ def weight_loader(self, "Loading a weight without `output_dim` attribute in " "QKVParallelLinear, assume the weight is the same " "for all partitions.") - assert param_data.shape == loaded_weight.shape + + assert ( + param_data.shape == loaded_weight.shape or + (len(param_data.shape) == 0 and len(loaded_weight.shape) == 0) + ) param_data.copy_(loaded_weight) @@ -483,6 +578,7 @@ class RowParallelLinear(torch.nn.Module): | A_p | - - Arguments: + layer_name: name of the layer in the state dict. input_size: first dimension of matrix A. output_size: second dimension of matrix A. bias: If true, add bias. Note that bias is not parallelized. 
@@ -498,6 +594,7 @@ class RowParallelLinear(torch.nn.Module): def __init__( self, + layer_name: str, input_size: int, output_size: int, bias: bool = True, @@ -509,6 +606,7 @@ def __init__( ): super().__init__() # Keep input parameters + self.layer_name = layer_name self.input_size = input_size self.output_size = output_size self.input_is_parallel = input_is_parallel @@ -525,8 +623,13 @@ def __init__( linear_method = UnquantizedLinearMethod() self.linear_method = linear_method self.linear_weights = self.linear_method.create_weights( - self.input_size_per_partition, self.output_size, self.input_size, - self.output_size, self.params_dtype) + layer_name=self.layer_name, + input_size_per_partition=self.input_size_per_partition, + output_sizes_per_partition=[self.output_size], + input_size=self.input_size, + output_size=self.output_size, + params_dtype=self.params_dtype, + ) for name, weight in self.linear_weights.items(): if isinstance(weight, torch.Tensor): self.register_parameter(name, weight) @@ -555,6 +658,11 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): start_idx = tp_rank * shard_size loaded_weight = loaded_weight.narrow(input_dim, start_idx, shard_size) + + # TODO: canon + if len(loaded_weight.shape) == 0: + loaded_weight = loaded_weight.reshape(1) + assert param_data.shape == loaded_weight.shape param_data.copy_(loaded_weight) diff --git a/vllm/model_executor/layers/quantization/awq.py b/vllm/model_executor/layers/quantization/awq.py index 2caef5f1ebf5..7cf94ae9f44e 100644 --- a/vllm/model_executor/layers/quantization/awq.py +++ b/vllm/model_executor/layers/quantization/awq.py @@ -79,10 +79,17 @@ class AWQLinearMethod(LinearMethodBase): def __init__(self, quant_config: AWQConfig): self.quant_config = quant_config - def create_weights(self, input_size_per_partition: int, - output_size_per_partition: int, input_size: int, - output_size: int, - params_dtype: torch.dtype) -> Dict[str, Any]: + def create_weights( + self, + layer_name: str, + input_size_per_partition: int, + output_sizes_per_partition: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype) -> Dict[str, Any]: + del layer_name, input_size, output_size # Unused. + output_size_per_partition = sum(output_sizes_per_partition) + if input_size_per_partition % self.quant_config.group_size != 0: raise ValueError( "The input size is not aligned with the quantized " diff --git a/vllm/model_executor/layers/quantization/base_config.py b/vllm/model_executor/layers/quantization/base_config.py index 6115e7c3be95..868e09252bb2 100644 --- a/vllm/model_executor/layers/quantization/base_config.py +++ b/vllm/model_executor/layers/quantization/base_config.py @@ -51,14 +51,12 @@ def get_from_keys(config: Dict[str, Any], keys: List[str]) -> Any: "quantization config.") @abstractmethod - def get_linear_method(self) -> LinearMethodBase: + def get_linear_method(self, name) -> LinearMethodBase: """Get the linear method to use for the quantized linear layer.""" raise NotImplementedError - @abstractmethod def get_scaled_act_names(self) -> List[str]: """Returns the activation function names that should be post-scaled. - For now, this is only used by AWQ. 
""" - raise NotImplementedError + return [] diff --git a/vllm/model_executor/layers/quantization/gptq.py b/vllm/model_executor/layers/quantization/gptq.py index 53baf710ed81..8c3492ae67d8 100644 --- a/vllm/model_executor/layers/quantization/gptq.py +++ b/vllm/model_executor/layers/quantization/gptq.py @@ -89,13 +89,16 @@ def __init__(self, quant_config: GPTQConfig): def create_weights( self, + layer_name: str, input_size_per_partition: int, - output_size_per_partition: int, + output_sizes_per_partition: List[int], input_size: int, output_size: int, params_dtype: torch.dtype, ) -> Dict[str, Any]: - del output_size # Unused. + del output_size, layer_name # Unused. + output_size_per_partition = sum(output_sizes_per_partition) + if input_size_per_partition % self.quant_config.group_size != 0: raise ValueError( "The input size is not aligned with the quantized " diff --git a/vllm/model_executor/layers/quantization/marlin.py b/vllm/model_executor/layers/quantization/marlin.py index 784229878edf..59d217567919 100644 --- a/vllm/model_executor/layers/quantization/marlin.py +++ b/vllm/model_executor/layers/quantization/marlin.py @@ -91,13 +91,15 @@ def __init__(self, quant_config: MarlinConfig): def create_weights( self, + layer_name: str, input_size_per_partition: int, - output_size_per_partition: int, + output_sizes_per_partition: List[int], input_size: int, output_size: int, params_dtype: torch.dtype, ) -> Dict[str, Any]: - del output_size # Unused. + del layer_name, input_size, output_size # Unused. + output_size_per_partition = sum(output_sizes_per_partition) if params_dtype != torch.float16: raise ValueError( diff --git a/vllm/model_executor/layers/quantization/smoothquant.py b/vllm/model_executor/layers/quantization/smoothquant.py deleted file mode 100644 index d9d82e6cfbc3..000000000000 --- a/vllm/model_executor/layers/quantization/smoothquant.py +++ /dev/null @@ -1,348 +0,0 @@ -from typing import Any, Dict, List, Tuple, Optional - -import torch -from torch._tensor import Tensor -from torch.nn.parameter import Parameter -import threading - -from vllm._C import ops -from vllm.model_executor.layers.linear import (LinearMethodBase, - set_weight_attrs) -from vllm.model_executor.layers.quantization.base_config import QuantizationConfig - - -class SmoothQuantConfig(QuantizationConfig): - """Config class for SmoothQuant - - Reference: https://github.com/mit-han-lab/smoothquant - """ - - def __init__(self, - weight_bits: int = 8, - quant_map: dict[str:str] = None) -> None: - self.weight_bits = weight_bits - self.quant_map = quant_map - - if self.weight_bits != 8: - raise ValueError( - "Currently, only w8a8 quantization is supported for " - f"SmoothQuant, but got {self.weight_bits} bits.") - if self.quant_map is None or self.quant_map == {}: - raise ValueError( - 'Quant_map for SmoothQuant should not be None or an empty dict. ' - 'For example, when using llama, you should set a quant_config.json in model directory, like ' - '{ "qkv": "per-tensor", "out": "per-token", "fc1": "per-tensor", "fc2": "per-token" }' - ) - - def __repr__(self) -> str: - return (f"SmoothQuantConfig(weight_bits={self.weight_bits}, " - f"quant_map={self.quant_map})") - - def get_name(self) -> str: - return "smoothquant" - - def get_supported_act_dtypes(self) -> List[torch.dtype]: - return [torch.half, torch.float] - - def get_min_capability(self) -> int: - # The smoothquant kernel only supports Ampere or newer GPUs. 
- return 80 - - @classmethod - def get_config_filenames(cls) -> List[str]: - """List of filenames to search for in the model directory.""" - return [ - "quant_config.json", - "quantize_config.json", - ] - - @classmethod - def from_config(cls, config: Dict[str, Any]) -> "SmoothQuantConfig": - try: - weight_bits = cls.get_from_keys(config, ["w_bit", "bits"]) - except ValueError as e: - weight_bits = 8 - print(str(e) + " Set weight_bits = 8 by default.") - - quant_map = {} - for key, value in config.items(): - if value in ["per-tensor", "per-token"]: - quant_map[key] = value - return cls(weight_bits, quant_map) - - def get_linear_method(self) -> "SQLinearMethod": - return SQLinearMethod(Int8GEMM) - - def get_scaled_act_names(self) -> List[str]: - return [] - - -class Int8GEMM(object): - _instance_lock = threading.Lock() - - def __init__(self): - if not hasattr(self, "i8cugemm"): - self.i8cugemm = ops.I8CUGEMM() - - def __new__(cls, *args, **kwargs): - if not hasattr(Int8GEMM, "_instance"): - with Int8GEMM._instance_lock: - if not hasattr(Int8GEMM, "_instance"): - Int8GEMM._instance = object.__new__(cls) - return Int8GEMM._instance - - def get_i8cugemm(self): - return self.i8cugemm - - -class SQLinearMethod(LinearMethodBase): - """Linear method for SmoothQuant. - """ - - def __init__(self, gemm): - i8_gemm = gemm() - self.i8cugemm = i8_gemm.get_i8cugemm() - - def create_weights(self, input_size_per_partition: int, - output_size_per_partition: int, input_size: int, - output_size: int, - params_dtype: torch.dtype) -> Dict[str, Tensor]: - weight = Parameter( - torch.empty( - output_size_per_partition, - input_size_per_partition, - device="cuda", - dtype=torch.int8, - ), - requires_grad=False, - ) - set_weight_attrs(weight, { - "input_dim": 1, - "output_dim": 0, - }) - # q k v dequant_scales are used in QKVParallelLinear - q_dequant_scale = Parameter( - torch.tensor(1.0, dtype=torch.float32, device='cpu'), - requires_grad=False, - ) - k_dequant_scale = Parameter( - torch.tensor(1.0, dtype=torch.float32, device='cpu'), - requires_grad=False, - ) - v_dequant_scale = Parameter( - torch.tensor(1.0, dtype=torch.float32, device='cpu'), - requires_grad=False, - ) - # gate up dequant_scales are used in MergedColumnParallelLinear - gate_dequant_scale = Parameter( - torch.tensor(1.0, dtype=torch.float32, device='cpu'), - requires_grad=False, - ) - up_dequant_scale = Parameter( - torch.tensor(1.0, dtype=torch.float32, device='cpu'), - requires_grad=False, - ) - # dequant_scale is used in RowParallelLinear - dequant_scale = Parameter( - torch.tensor(1.0, dtype=torch.float32, device='cpu'), - requires_grad=False, - ) - return { - "weight": weight, - "q_dequant_scale": q_dequant_scale, - "k_dequant_scale": k_dequant_scale, - "v_dequant_scale": v_dequant_scale, - "gate_dequant_scale": gate_dequant_scale, - "up_dequant_scale": up_dequant_scale, - "dequant_scale": dequant_scale - } - - def apply_weights(self, - weights: Dict[str, Tensor], - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> Tensor: - assert bias is None - weight = weights["weight"] - x_shape = x.shape - x = x.view(-1, x_shape[-1]) - y = torch.empty((x.shape[0], weight.shape[0]), - dtype=torch.int32, - device=x.device) - self.i8cugemm.linear_a8_w8_o32_(x, weight, y) - y = y.view(*x_shape[:-1], -1) - return y - - -class SQLinearMethodQKV(SQLinearMethod): - - def __init__(self, - gemm, - qkv_sizes : Tuple[int, int, int], - quant_dtype : torch.dtype = torch.int8, - dequant_dtype : torch.dtype = torch.float): - super().__init__(gemm) - 
self.qkv_sizes = qkv_sizes - self.quant_dtype = quant_dtype - self.dequant_dtype = dequant_dtype - - def quantize(self, x): - assert x.dtype != self.quant_dtype - x_q = torch.empty_like(x, dtype=self.quant_dtype) - ops.quant(x_q, x, 1.0) - return x_q - - def dequantize(self, x_q, weights : Dict[str, Tensor]): - # split to get the quantized qkv - q_q, k_q, v_q = x_q.split(list(self.qkv_sizes), dim=-1) - - # create dequant qkv buffer and split to get the individual dequant qkv - # buffers - qkv = torch.empty_like(x_q, dtype=self.dequant_dtype) - q, k, v = qkv.split(list(self.qkv_sizes), dim=-1) - - q_scale, k_scale, v_scale = (weights['q_dequant_scale'], - weights['k_dequant_scale'], - weights['v_dequant_scale']) - ops.dequant(q, q_q, q_scale) - ops.dequant(k, k_q, k_scale) - ops.dequant(v, v_q, v_scale) - - return qkv - - def apply_weights(self, - weights: Dict[str, Tensor], - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> Tensor: - x_q = self.quantize(x) - y_q = super().apply_weights(weights, x_q, bias) - return self.dequantize(y_q, weights) - -class SQLinearMethodOProj(SQLinearMethod): - - def __init__(self, - gemm, - use_per_token_quant:bool, - quant_dtype : torch.dtype = torch.int8, - dequant_dtype : torch.dtype = torch.float): - super().__init__(gemm) - self.use_per_token_quant = use_per_token_quant - self.quant_dtype = quant_dtype - self.dequant_dtype = dequant_dtype - - def quantize(self, x) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: - # x is the paged-attention output - assert x.dtype != self.quant_dtype - act_scale = None - x_q = torch.empty_like(x, dtype=self.quant_dtype) - if self.use_per_token_quant: - act_scale = torch.empty(x.numel() // x.shape[-1], - dtype=torch.float32, - device=x.device) - ops.quant(x_q, x, act_scale) - else: - ops.quant(x_q, x, 1.0) - return x_q, act_scale - - def dequantize(self, x_q, weights : Dict[str, Tensor], act_scale : torch.Tensor) -> torch.Tensor: - o_dequant_scale = weights['dequant_scale'] - x = torch.empty_like( - x_q, - dtype=self.dequant_dtype, - device=x_q.device) - ops.dequant(x, x_q, act_scale, o_dequant_scale) - return x - - def apply_weights(self, - weights: Dict[str, Tensor], - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> Tensor: - pass - x_q, act_scale = self.quantize(x) - y_q = super().apply_weights(weights, x_q, bias) - return self.dequantize(y_q, weights, act_scale) - -class SQLinearMethodGateUpProj(SQLinearMethod): - - def __init__(self, - gemm, - quant_dtype : torch.dtype = torch.int8, - dequant_dtype : torch.dtype = torch.float): - super().__init__(gemm) - self.quant_dtype = quant_dtype - self.dequant_dtype = dequant_dtype - - def quantize(self, x) -> torch.Tensor: - # x is the attention output - assert x.dtype != self.quant_dtype - x_q = torch.empty_like(x, dtype=self.quant_dtype, device=x.device) - ops.quant(x_q, x, 1.0) - return x_q - - def dequantize(self, gate_up_q: torch.Tensor, weights : Dict[str, Tensor]) -> torch.Tensor: - - def split_gate_up(gate_up : torch.Tensor): - d = gate_up.shape[-1] - return (torch.narrow(gate_up, 1, 0, d//2), - torch.narrow(gate_up, 1, d//2, d//2)) - - # create a dequant gate_up buffer and split it into constituent parts. - gate_up = torch.empty_like(gate_up_q, - dtype=self.dequant_dtype, - device=gate_up_q.device) - - # split quantized gate_up into constituent parts. - gate_q, up_q = split_gate_up(gate_up_q) - # split output gate_up buffer into constituent parts. 
- gate, up = split_gate_up(gate_up) - - gate_scale, up_scale = (weights['gate_dequant_scale'], - weights['up_dequant_scale']) - ops.dequant(gate, gate_q, gate_scale) - ops.dequant(up, up_q, up_scale) - - return gate_up - - def apply_weights(self, - weights: Dict[str, Tensor], - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> Tensor: - x_q = self.quantize(x) - gate_up_q = super().apply_weights(weights, x_q, bias) - return self.dequantize(gate_up_q, weights) - -class SQLinearMethodDownProj(SQLinearMethod): - - def __init__(self, - gemm, - quant_dtype : torch.dtype = torch.int8, - dequant_dtype : torch.dtype = torch.float): - super().__init__(gemm) - self.quant_dtype = quant_dtype - self.dequant_dtype = dequant_dtype - - def quantize(self, x) -> Tuple[torch.Tensor, torch.Tensor]: - assert x.dtype != self.quant_dtype - # TODO (varun) : This is per-token quant - Read from config - x_q = torch.empty_like(x, dtype=self.quant_dtype) - scale = torch.empty(x.numel() // x.shape[-1], - dtype=torch.float32, - device=x.device) - ops.quant(x_q, x, scale) - return x_q, scale - - def dequantize(self, x_q, weights : Dict[str, Tensor], act_scale : torch.Tensor) -> torch.Tensor: - down_dequant_scale = weights['dequant_scale'] - x = torch.empty_like( - x_q, - dtype=self.dequant_dtype, - device=x_q.device) - ops.dequant(x, x_q, act_scale, down_dequant_scale) - return x - - def apply_weights(self, - weights: Dict[str, Tensor], - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: - x_q, act_scale = self.quantize(x) - y_q = super().apply_weights(weights, x_q, bias) - return self.dequantize(y_q, weights, act_scale) \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/smoothquant/__init__.py b/vllm/model_executor/layers/quantization/smoothquant/__init__.py new file mode 100644 index 000000000000..2f62cee49d95 --- /dev/null +++ b/vllm/model_executor/layers/quantization/smoothquant/__init__.py @@ -0,0 +1,14 @@ +from vllm.model_executor.layers.quantization.smoothquant.formats import ( + SmoothQuantFormat +) + +from vllm.model_executor.layers.quantization.smoothquant.config import ( + SmoothQuantConfig, + SmoothQuantLinearMethod +) + +__all__ = [ + "SmoothQuantFormat", + "SmoothQuantConfig", + "SmoothQuantLinearMethod", +] diff --git a/vllm/model_executor/layers/quantization/smoothquant/config.py b/vllm/model_executor/layers/quantization/smoothquant/config.py new file mode 100644 index 000000000000..885ffce3e36d --- /dev/null +++ b/vllm/model_executor/layers/quantization/smoothquant/config.py @@ -0,0 +1,306 @@ +from typing import Any, Dict, List, Tuple, Type, Optional, Union +import threading + +import torch +from torch.nn.parameter import Parameter + +from vllm._C import ops +from vllm.model_executor.layers.linear import ( + LinearMethodBase, + set_weight_attrs) +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.quantization.smoothquant.formats import ( + SmoothQuantFormat, + SmoothQuantDynamicPerToken, + SmoothQuantStaticPerTensor, +) + +LAYER_KEYS = ["qkv", "out", "fc1", "fc2"] +FORMAT_REGISTRY = { + "per-token": SmoothQuantDynamicPerToken, + "per-tensor": SmoothQuantStaticPerTensor, +} + +def get_sq_format_cls(format_key: str) -> Type[SmoothQuantFormat]: + if format_key not in FORMAT_REGISTRY: + raise ValueError(f"Invalid smoothquant format: {format_key}") + return FORMAT_REGISTRY[format_key] + +class SmoothQuantConfig(QuantizationConfig): + """Config class for SmoothQuant. 
+
+    Reference: https://github.com/mit-han-lab/smoothquant
+    """
+    def __init__(self,
+                 layer_format_map: Dict[str, str]) -> None:
+        self.layer_format_map = layer_format_map
+
+        for key, format in self.layer_format_map.items():
+            if key not in LAYER_KEYS:
+                raise ValueError(
+                    f"Found key of {key} in {self.layer_format_map}, "
+                    f"but key must be one of {LAYER_KEYS}"
+                )
+            if format not in FORMAT_REGISTRY:
+                raise ValueError(
+                    f"Found format of {format} in {self.layer_format_map}, "
+                    f"but format must be one of {FORMAT_REGISTRY}"
+                )
+        for key in LAYER_KEYS:
+            if key not in self.layer_format_map:
+                raise ValueError(
+                    f"Could not find {key} in {layer_format_map}"
+                )
+
+    def __repr__(self) -> str:
+        return (f"SmoothQuantConfig(layer_format_map={self.layer_format_map})")
+
+    def get_name(self) -> str:
+        return "smoothquant"
+
+    def get_supported_act_dtypes(self) -> List[torch.dtype]:
+        # TODO: check if we support fp16 / bf16 as well.
+        return [torch.float]
+
+    def get_min_capability(self) -> int:
+        # TODO: check if this is right.
+        return 80
+
+    @classmethod
+    def get_config_filenames(cls) -> List[str]:
+        """List of filenames to search for in the model directory."""
+        return [
+            "quant_config.json",
+        ]
+
+    @classmethod
+    def from_config(cls, config: Dict[str, Any]) -> "SmoothQuantConfig":
+        layer_format_map: Dict[str, str] = {}
+        for layer_key, format in config.items():
+            if format in FORMAT_REGISTRY:
+                layer_format_map[layer_key] = format
+        return cls(layer_format_map)
+
+    def get_linear_method(self) -> "SmoothQuantLinearMethod":
+        return SmoothQuantLinearMethod(self)
+
+
+# TODO: why is this needed?
+class Int8GEMM(object):
+    _instance_lock = threading.Lock()
+
+    def __init__(self):
+        if not hasattr(self, "i8cugemm"):
+            self.i8cugemm = ops.I8CUGEMM()
+
+    def __new__(cls, *args, **kwargs):
+        if not hasattr(Int8GEMM, "_instance"):
+            with Int8GEMM._instance_lock:
+                if not hasattr(Int8GEMM, "_instance"):
+                    Int8GEMM._instance = object.__new__(cls)
+        return Int8GEMM._instance
+
+    def get_i8cugemm(self):
+        return self.i8cugemm
+
+
+class SmoothQuantLinearMethod(LinearMethodBase):
+    def __init__(self, sq_config: SmoothQuantConfig) -> None:
+        self.sq_config = sq_config
+        self.sq_type = None
+        self.i8cugemm = Int8GEMM().get_i8cugemm()
+
+    def maybe_update_loaded_weight_name(self,
+                                        name: str) -> str:
+        """Convert a serialized name such as k_dequant_scale to dequant_scale.
+
+        This function is called by model_cls.load_weights() during the weight
+        loading process to match the on-disk state dict to the vllm state dict.
+        """
+        if "dequant_scale" in name:
+            suffix = name.split('.')[-1]
+            name = name.replace(suffix, "dequant_scale")
+        return name
+
+    def scales_shard_splitter(self,
+                              param: torch.Tensor,
+                              loaded_weight: torch.Tensor,
+                              shard_id: Union[str, int]) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Index into param for loading.
+
+        This function is called by QKVParallelLinear and MergedColumnParallelLinear
+        during weight loading to put the scales from disk in the right spot.
+        """
+        if type(shard_id) == str:
+            qkv_idxs = { "q": 0, "k": 1, "v": 2 }
+            if shard_id not in qkv_idxs:
+                raise ValueError(f"Invalid shard_id {shard_id}")
+            shard_id = qkv_idxs[shard_id]
+        elif type(shard_id) != int:
+            raise ValueError(f"Invalid shard id {shard_id}")
+
+        return param[shard_id], loaded_weight
+
+    def get_layer_format(self, layer_name: str) -> SmoothQuantFormat:
+        """
+        Gets the SmoothQuantFormat for a specific layer.
+
+        SmoothQuantLinearMethod uses SmoothQuantFormat to support non-uniform quantization
+        (where each layer has a different format). To determine the SmoothQuantFormat
+        for a layer, we match the layer_name to the layer_keys=["qkv","out","fc1","fc2"]
+        and use layer_format_map to determine the SQFormat.
+
+        Args:
+            layer_name: Name of the layer we are creating the LinearMethod for.
+        Returns:
+            sq_format: SmoothQuantFormat with the right settings for this layer.
+        """
+        # Note: AutoSmoothQuant Serialization is not very good yet.
+        #
+        # It looks like the following (which does not map to layer names in the model):
+        # {
+        #     "qkv": "per-tensor",
+        #     "out": "per-token",
+        #     "fc1": "per-tensor",
+        #     "fc2": "per-token"
+        # }
+        #
+        # So, this is a hack for llama now. But with the SparseMLConfig, we can make this robust,
+        # where we actually use the layer_name in the model to look up what the format is
+        # based on the config.
+        #
+        # What it would actually look like:
+        # layer_config is None
+        # for supported_key in SUPPORTED_LAYER_KEYS:
+        #     if supported_key in layer_name:
+        #         sq_format = self.layer_mapping[lookup_key]
+        #         return get_sq_format_cls(sq_format)()
+
+        HACKED_REMAP_FOR_LLAMA = {
+            "qkv": "qkv",
+            "o_proj": "out",
+            "gate_up": "fc1",
+            "down": "fc2",
+        }
+
+        for match_key, lookup_key in HACKED_REMAP_FOR_LLAMA.items():
+            if match_key in layer_name:
+                sq_format = self.sq_config.layer_format_map[lookup_key]
+                return get_sq_format_cls(sq_format)()
+
+        raise ValueError(f"Could not find a SmoothQuant format for layer {layer_name}")
+
+    def create_weights(self,
+                       layer_name: str,
+                       input_size_per_partition: int,
+                       output_sizes_per_partition: List[int],
+                       input_size: int,
+                       output_size: int,
+                       params_dtype: torch.dtype) -> Dict[str, torch.Tensor]:
+        del input_size, output_size
+
+        # Statically Quantized Weights.
+        weight = Parameter(
+            torch.empty(
+                sum(output_sizes_per_partition),
+                input_size_per_partition,
+                device="cuda", dtype=torch.int8,
+            ), requires_grad=False,
+        )
+        set_weight_attrs(weight, {
+            "input_dim": 1,
+            "output_dim": 0,
+        })
+
+        # Static scale for each logical weight (e.g. 3 for QKV).
+        dequant_scale = Parameter(
+            torch.empty(
+                len(output_sizes_per_partition),
+                device='cuda', dtype=params_dtype,
+            ), requires_grad=False
+        )
+        set_weight_attrs(dequant_scale, {
+            "shard_splitter": self.scales_shard_splitter,
+        })
+
+        return {
+            "weight": weight,
+            "dequant_scale": dequant_scale,
+            "logical_widths": output_sizes_per_partition,
+            "sq_format": self.get_layer_format(layer_name)
+        }
+
+    def _quantize(self,
+                  x: torch.Tensor,
+                  sq_format: SmoothQuantFormat) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        """Quantize activations.
+
+        Args:
+            x: Activation at floating point precision.
+        Returns:
+            x_q: Quantized activation at INT8.
+            activation_scales: Optional dynamic scales for each token.
+        """
+        x_q = torch.empty_like(x, dtype=torch.int8)
+        x_q, activation_scales = sq_format.quantize_op(x, x_q)
+        return x_q, activation_scales
+
+    def _dequantize(self,
+                    x_q: torch.Tensor,
+                    dynamic_scales: Optional[torch.Tensor],
+                    static_scales: torch.Tensor,
+                    logical_widths: List[int],
+                    dtype: torch.dtype,
+                    sq_format: SmoothQuantFormat) -> torch.Tensor:
+        """Dequantize activations.
+
+        Args:
+            x_q: quantized activations.
+            dynamic_scales: Optional dynamic scales.
+            static_scales: Static dequantization scales.
+            logical_widths: Width of each logical activation (for QKV case).
+            dtype: Datatype to dequantize to.
+        Returns:
+            x_dq: dequantized activation at dtype precision.
+        """
+        # Split X_q and X_dq buffer into logical activations (for QKV case).
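+        # torch.split returns views, so writing into each chunk of x_dq_split
+        # below fills the single x_dq output buffer.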
+        x_q_split = x_q.split(logical_widths, dim=-1)
+        x_dq = torch.empty_like(x_q, dtype=dtype)
+        x_dq_split = x_dq.split(logical_widths, dim=-1)
+        # Dequantize in place and return.
+        sq_format.dequantize_op(x_q_split, x_dq_split, dynamic_scales, static_scales)
+        return x_dq
+
+
+    def apply_weights(self,
+                      weights: Dict[str, torch.Tensor],
+                      x: torch.Tensor,
+                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+        """Forward method. Computes Q --> GEMM --> DQ.
+
+        Args:
+            weights: Dictionary of weights, scales, and metadata.
+            x: Input in floating point precision.
+            bias: Optional bias.
+        Returns:
+            a_dq: Dequantized activation at floating point precision.
+        """
+        if bias is not None:
+            raise NotImplementedError
+        weight_q = weights["weight"]
+        static_scales = weights["dequant_scale"]
+        logical_widths = weights["logical_widths"]
+        sq_format = weights["sq_format"]
+
+        # Q
+        x_q, activation_scales = self._quantize(x, sq_format)
+
+        # GEMM
+        x_q = x_q.view(-1, x_q.shape[-1])
+        a_q = torch.empty((x_q.shape[0], weight_q.shape[0]), dtype=torch.int32, device="cuda")
+        self.i8cugemm.linear_a8_w8_o32_(x_q, weight_q, a_q)
+        a_q = a_q.view(*x.shape[:-1], -1)
+
+        # DQ
+        return self._dequantize(a_q, activation_scales, static_scales, logical_widths, x.dtype, sq_format)
diff --git a/vllm/model_executor/layers/quantization/smoothquant/formats.py b/vllm/model_executor/layers/quantization/smoothquant/formats.py
new file mode 100644
index 000000000000..b8ddd642c888
--- /dev/null
+++ b/vllm/model_executor/layers/quantization/smoothquant/formats.py
@@ -0,0 +1,100 @@
+from abc import ABC, abstractmethod
+from typing import List, Optional, Tuple, Type
+
+import torch
+
+from vllm._C import ops
+
+
+class SmoothQuantFormat(ABC):
+    @abstractmethod
+    def dequantize_op(self,
+                      x_qs: List[torch.Tensor],
+                      x_dqs: List[torch.Tensor],
+                      dynamic_scales: Optional[torch.Tensor],
+                      static_scales: torch.Tensor) -> None:
+        """Dequantize the activations. The x_dqs buffers are updated in place.
+
+        Args:
+            x_qs: List of N quantized activations.
+            x_dqs: List of N buffers to fill with dequantized values.
+            dynamic_scales: Optional dynamic scales for dequantization.
+            static_scales: Static scales for dequantization. N values.
+        """
+        raise NotImplementedError
+
+
+    @abstractmethod
+    def quantize_op(self,
+                    x: torch.Tensor,
+                    x_q: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        """Quantize the input and (optionally) compute dequant scales.
+
+        Args:
+            x: Input data in floating point format.
+            x_q: Buffer for quantized inputs.
+        Returns:
+            x_q: Quantized input.
+            activation_scales: Optional dynamic scales for the activations.
+        """
+        raise NotImplementedError
+
+
+class SmoothQuantDynamicPerToken(SmoothQuantFormat):
+    def dequantize_op(self,
+                      x_qs: List[torch.Tensor],
+                      x_dqs: List[torch.Tensor],
+                      dynamic_scales: Optional[torch.Tensor],
+                      static_scales: torch.Tensor) -> None:
+        """Notes:
+            dynamic_scales: N scales for N tokens in the activation.
+            static_scales: K scales for K logical activations (equals just w_scale).
+        """
+        if dynamic_scales is None:
+            raise ValueError("Per-token dequantization requires dynamic activation scales")
+
+        # Dequantize each logical activation.
+        # TODO: test this for case when logical_widths > 1 (may need to reshape)
+        for x_dq, x_q, dynamic_scale, static_scale in zip(
+            x_dqs, x_qs, dynamic_scales, static_scales):
+
+            # Dequantize (updates x_dq in place).
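+            # The custom kernel is assumed to take (out, x_q, per-token scales, static scale)
+            # here; the per-tensor path below passes a single combined scale instead.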
+ ops.dequant(x_dq, x_q, dynamic_scale, static_scale) + + + def quantize_op(self, + x: torch.Tensor, + x_q: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """Notes: + Returns quantized activaiton and dynamic activation scales. + """ + activation_scales = torch.empty(x.numel() // x.shape[-1], dtype=x.dtype, device=x.device) + ops.quant(x_q, x, activation_scales) + return x_q, activation_scales + + +class SmoothQuantStaticPerTensor(SmoothQuantFormat): + def dequantize_op(self, + x_qs: List[torch.Tensor], + x_dqs: List[torch.Tensor], + dynamic_scales: Optional[torch.Tensor], + static_scales: torch.Tensor) -> None: + """Notes: + dynamic_scales: None + static_scales: K scales for K logical activations (equals w_scale * a_scale). + """ + if dynamic_scales is not None: + raise ValueError + + # Dequantize each logical activation. + for xdq, xq, static_scale in zip(x_dqs, x_qs, static_scales): + ops.dequant(xdq, xq, static_scale) + + def quantize_op(self, + x: torch.Tensor, + x_q: torch.Tensor) -> Tuple[torch.Tensor, None]: + """Notes: + Returns quantized activaiton and no dynamic scales. + """ + ops.quant(x_q, x, 1.0) + return x_q, None diff --git a/vllm/model_executor/layers/quantization/squeezellm.py b/vllm/model_executor/layers/quantization/squeezellm.py index ed25455e6ec1..893e6781089d 100644 --- a/vllm/model_executor/layers/quantization/squeezellm.py +++ b/vllm/model_executor/layers/quantization/squeezellm.py @@ -68,10 +68,17 @@ class SqueezeLLMLinearMethod(LinearMethodBase): def __init__(self, quant_config: SqueezeLLMConfig): self.quant_config = quant_config - def create_weights(self, input_size_per_partition: int, - output_size_per_partition: int, input_size: int, - output_size: int, - params_dtype: torch.dtype) -> Dict[str, Any]: + def create_weights( + self, + layer_name: str, + input_size_per_partition: int, + output_sizes_per_partition: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype) -> Dict[str, Any]: + del layer_name, input_size # Unused. + output_size_per_partition = sum(output_sizes_per_partition) + if input_size_per_partition % self.quant_config.pack_factor != 0: raise ValueError( "The input size is not aligned with the quantized " diff --git a/vllm/model_executor/model_loader.py b/vllm/model_executor/model_loader.py index b191dc4009b5..de7910c4860b 100644 --- a/vllm/model_executor/model_loader.py +++ b/vllm/model_executor/model_loader.py @@ -46,10 +46,6 @@ def _get_model_architecture( def get_architecture_class_name(model_config: ModelConfig) -> str: return _get_model_architecture(model_config)[1] -def _is_support_smoothquant(model_config: ModelConfig) -> bool: - architectures = getattr(model_config.hf_config, "architectures", []) - supported_archs = ModelRegistry.get_supported_smoothquant_archs() - return any(arch in supported_archs for arch in architectures) def get_model(model_config: ModelConfig, device_config: DeviceConfig, **kwargs) -> nn.Module: @@ -82,10 +78,7 @@ def get_model(model_config: ModelConfig, device_config: DeviceConfig, # Create a model instance. # The weights will be initialized as empty tensors. 
with torch.device(device_config.device): - if _is_support_smoothquant(model_config): - model = model_class(model_config.hf_config, linear_method, - quant_config) - elif hasattr(model_class, "supported_lora_modules"): + if hasattr(model_class, "supported_lora_modules"): model = model_class(model_config.hf_config, linear_method, lora_config) else: diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 1fffbc5fa30c..0b6c75705764 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -29,21 +29,12 @@ from vllm.attention import Attention, AttentionMetadata from vllm.config import LoRAConfig -from vllm.model_executor.layers.quantization.smoothquant import ( - Int8GEMM, - SQLinearMethod, - SQLinearMethodQKV, - SQLinearMethodOProj, - SQLinearMethodGateUpProj, - SQLinearMethodDownProj) -from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (LinearMethodBase, MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear) - -from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler @@ -58,6 +49,7 @@ from vllm.sequence import SamplerOutput from vllm.utils import is_hip + class LlamaMLP(nn.Module): def __init__( @@ -65,43 +57,25 @@ def __init__( hidden_size: int, intermediate_size: int, hidden_act: str, + parent_name: str, linear_method: Optional[LinearMethodBase] = None, - quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() - self.hidden_size = hidden_size - self.use_int8 = quant_config is not None and quant_config.get_name( - ) == "smoothquant" - - gate_up_linear_method = linear_method - if self.use_int8: - # override gate_up linear method - assert isinstance(linear_method, SQLinearMethod) - gate_up_linear_method = SQLinearMethodGateUpProj( - gemm=Int8GEMM, - quant_dtype=torch.int8, - dequant_dtype=torch.float) self.gate_up_proj = MergedColumnParallelLinear( - hidden_size, [intermediate_size] * 2, + layer_name=f"{parent_name}.gate_up_proj", + input_size=hidden_size, + output_sizes=[intermediate_size] * 2, bias=False, - linear_method=gate_up_linear_method) + linear_method=linear_method) + self.down_proj = RowParallelLinear( + layer_name=f"{parent_name}.down_proj", + input_size=intermediate_size, + output_size=hidden_size, + bias=False, + linear_method=linear_method) if hidden_act != "silu": raise ValueError(f"Unsupported activation: {hidden_act}. 
" "Only silu is supported for now.") - - down_proj_linear_method = linear_method - if self.use_int8: - # override gate_up linear method - assert isinstance(linear_method, SQLinearMethod) - down_proj_linear_method = SQLinearMethodDownProj( - gemm=Int8GEMM, - quant_dtype=torch.int8, - dequant_dtype=torch.float) - self.down_proj = RowParallelLinear(intermediate_size, - hidden_size, - bias=False, - linear_method=down_proj_linear_method) - self.act_fn = SiluAndMul() def forward(self, x): @@ -110,6 +84,7 @@ def forward(self, x): x, _ = self.down_proj(x) return x + class LlamaAttention(nn.Module): def __init__( @@ -117,42 +92,36 @@ def __init__( hidden_size: int, num_heads: int, num_kv_heads: int, + parent_name: str, rope_theta: float = 10000, rope_scaling: Optional[Dict[str, Any]] = None, max_position_embeddings: int = 8192, linear_method: Optional[LinearMethodBase] = None, - quant_config: Optional[QuantizationConfig] = None, bias: bool = False, sliding_window: Optional[int] = None, ) -> None: super().__init__() self.hidden_size = hidden_size - self.tp_size = get_tensor_model_parallel_world_size() + tp_size = get_tensor_model_parallel_world_size() self.total_num_heads = num_heads - assert self.total_num_heads % self.tp_size == 0 - self.num_heads = self.total_num_heads // self.tp_size + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size self.total_num_kv_heads = num_kv_heads - self.default_dtype = torch.get_default_dtype() - - if self.total_num_kv_heads >= self.tp_size: + if self.total_num_kv_heads >= tp_size: # Number of KV heads is greater than TP size, so we partition # the KV heads across multiple tensor parallel GPUs. - assert self.total_num_kv_heads % self.tp_size == 0 + assert self.total_num_kv_heads % tp_size == 0 else: # Number of KV heads is less than TP size, so we replicate # the KV heads across multiple tensor parallel GPUs. - assert self.tp_size % self.total_num_kv_heads == 0 - self.num_kv_heads = max(1, self.total_num_kv_heads // self.tp_size) + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) self.head_dim = hidden_size // self.total_num_heads self.q_size = self.num_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim self.scaling = self.head_dim**-0.5 self.rope_theta = rope_theta self.max_position_embeddings = max_position_embeddings - self.use_int8 = quant_config is not None and quant_config.get_name( - ) == "smoothquant" - # Needs to be ironed out!! - self.use_per_token_quant = self.use_int8 # This will be overwritten by model initialization if we are using it. # N.B. 
currently we only support per tensor scalar scaling factors @@ -163,55 +132,35 @@ def __init__( # scaling_factor = tensor_amax / FPtype_max self.kv_scale = 1.0 - self.rotary_emb = get_rope( - self.head_dim, - rotary_dim=self.head_dim, - max_position=max_position_embeddings, - base=rope_theta, - rope_scaling=rope_scaling, - ) - - qkv_linear_method = linear_method - if self.use_int8: - # override qkv linear method - assert isinstance(linear_method, SQLinearMethod) - qkv_linear_method = SQLinearMethodQKV( - gemm=Int8GEMM, - qkv_sizes=(self.q_size, self.kv_size, self.kv_size), - quant_dtype=torch.int8, - dequant_dtype=self.rotary_emb.cos_sin_cache.dtype) self.qkv_proj = QKVParallelLinear( - hidden_size, - self.head_dim, - self.total_num_heads, - self.total_num_kv_heads, + layer_name=f"{parent_name}.qkv_proj", + hidden_size=hidden_size, + head_size=self.head_dim, + total_num_heads=self.total_num_heads, + total_num_kv_heads=self.total_num_kv_heads, bias=bias, - linear_method=qkv_linear_method, + linear_method=linear_method, ) - - o_proj_linear_method = linear_method - if self.use_int8: - # override o_proj linear method - assert isinstance(linear_method, SQLinearMethod) - o_proj_linear_method = SQLinearMethodOProj( - gemm=Int8GEMM, - use_per_token_quant=True, # TODO (varun) : Read from config - quant_dtype = torch.int8, - dequant_dtype= torch.float) - self.o_proj = RowParallelLinear( - self.total_num_heads * self.head_dim, - hidden_size, + layer_name=f"{parent_name}.o_proj", + input_size=self.total_num_heads * self.head_dim, + output_size=hidden_size, bias=bias, - linear_method=o_proj_linear_method, + linear_method=linear_method, ) - self.attn = Attention( - self.num_heads, + self.rotary_emb = get_rope( self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads, - sliding_window=sliding_window) + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + ) + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + sliding_window=sliding_window) def forward( self, @@ -228,19 +177,17 @@ def forward( output, _ = self.o_proj(attn_output) return output + class LlamaDecoderLayer(nn.Module): def __init__( self, config: LlamaConfig, + parent_name: str, linear_method: Optional[LinearMethodBase] = None, - quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.hidden_size = config.hidden_size - self.use_int8 = quant_config is not None and quant_config.get_name( - ) == "smoothquant" - self.tp_size = get_tensor_model_parallel_world_size() rope_theta = getattr(config, "rope_theta", 10000) rope_scaling = getattr(config, "rope_scaling", None) max_position_embeddings = getattr(config, "max_position_embeddings", @@ -254,8 +201,8 @@ def __init__( rope_theta=rope_theta, rope_scaling=rope_scaling, max_position_embeddings=max_position_embeddings, + parent_name=f"{parent_name}.self_attn", linear_method=linear_method, - quant_config=quant_config, bias=getattr(config, "bias", False), sliding_window=sliding_window, ) @@ -263,8 +210,8 @@ def __init__( hidden_size=self.hidden_size, intermediate_size=config.intermediate_size, hidden_act=config.hidden_act, + parent_name=f"{parent_name}.mlp", linear_method=linear_method, - quant_config=quant_config, ) self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -279,7 +226,6 @@ def forward( attn_metadata: AttentionMetadata, residual: Optional[torch.Tensor], ) -> Tuple[torch.Tensor, torch.Tensor]: - # Self Attention if 
residual is None: residual = hidden_states @@ -287,7 +233,6 @@ def forward( else: hidden_states, residual = self.input_layernorm( hidden_states, residual) - hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, @@ -296,8 +241,8 @@ def forward( ) # Fully Connected - hidden_states, residual = self.post_attention_layernorm(hidden_states, - residual) + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) hidden_states = self.mlp(hidden_states) return hidden_states, residual @@ -308,8 +253,7 @@ def __init__( self, config: LlamaConfig, linear_method: Optional[LinearMethodBase] = None, - quant_config: Optional[QuantizationConfig] = None, - lora_config: Optional[LoRAConfig] = None, + lora_config: Optional[LoRAConfig] = None ) -> None: super().__init__() self.config = config @@ -324,8 +268,10 @@ def __init__( org_num_embeddings=config.vocab_size, ) self.layers = nn.ModuleList([ - LlamaDecoderLayer(config, linear_method, quant_config) - for _ in range(config.num_hidden_layers) + LlamaDecoderLayer(config, + parent_name=f"model.layers.{idx}", + linear_method=linear_method) + for idx in range(config.num_hidden_layers) ]) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -390,15 +336,12 @@ def __init__( self, config: LlamaConfig, linear_method: Optional[LinearMethodBase] = None, - quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, ) -> None: super().__init__() self.config = config self.linear_method = linear_method - self.quant_config = quant_config - self.model = LlamaModel(config, linear_method, lora_config=lora_config, - quant_config = quant_config) + self.model = LlamaModel(config, linear_method, lora_config=lora_config) self.unpadded_vocab_size = config.vocab_size if lora_config: self.unpadded_vocab_size += lora_config.lora_extra_vocab_size @@ -447,9 +390,6 @@ def load_weights(self, cache_dir: Optional[str] = None, load_format: str = "auto", revision: Optional[str] = None): - # For SmoothQuant - int8_fusion = self.quant_config is not None and \ - self.quant_config.get_name() == "smoothquant" stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -459,8 +399,13 @@ def load_weights(self, ("gate_up_proj", "up_proj", 1), ] params_dict = dict(self.named_parameters()) + for name, loaded_weight in hf_model_weights_iterator( model_name_or_path, cache_dir, load_format, revision): + # Update name of the loaded_weight if needed by the LinearMethod. + if self.linear_method: + name = self.linear_method.maybe_update_loaded_weight_name(name) + if "rotary_emb.inv_freq" in name: continue if ("rotary_emb.cos_cached" in name @@ -468,26 +413,6 @@ def load_weights(self, # Models trained using ColossalAI may include these tensors in # the checkpoint. Skip them. continue - # bias is useless for llama - if "bias" in name: - continue - # load dequant scale for qkv_proj and gate_up_proj - if int8_fusion: - is_fusion_scale = False - if "scale" in name: - for (param_name, weight_name, _) in stacked_params_mapping: - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - prefix = weight_name.split('_')[0] - suffix = name.split('.')[-1] - new_name = prefix + '_' + suffix - param = params_dict[name.replace(suffix, new_name)] - param.copy_(loaded_weight) - is_fusion_scale = True - break - if is_fusion_scale: - continue for (param_name, weight_name, shard_id) in stacked_params_mapping: if weight_name not in name: continue