diff --git a/src/transformers/core_model_loading.py b/src/transformers/core_model_loading.py index 080181e317a5..9559088a1039 100644 --- a/src/transformers/core_model_loading.py +++ b/src/transformers/core_model_loading.py @@ -115,6 +115,8 @@ def convert( source_keys: list[str], target_keys: list[str], full_layer_name: str, + model, + missing_keys, config, **kwargs, ) -> dict[str, list[torch.Tensor]]: @@ -138,6 +140,8 @@ def convert( source_keys: list[str], target_keys: list[str], full_layer_name: str, + model, + missing_keys, config, ) -> dict[str, list[torch.Tensor]]: tensors = next(iter(value.values())) @@ -163,6 +167,8 @@ def convert( source_keys: list[str], target_keys: list[str], full_layer_name: str, + model, + missing_keys, config, ) -> dict[str, torch.Tensor]: if len(target_keys) != 1: @@ -191,6 +197,8 @@ def convert( source_keys: list[str], target_keys: list[str], full_layer_name: str, + model, + missing_keys, config, ) -> dict[str, torch.Tensor]: merged: dict[str, torch.Tensor] = {} @@ -220,6 +228,8 @@ def convert( source_keys: list[str], target_keys: list[str], full_layer_name: str, + model, + missing_keys, config, ) -> dict[str, list[torch.Tensor]]: if len(value) != len(self.sizes): @@ -258,6 +268,8 @@ def convert( source_keys: list[str], target_keys: list[str], full_layer_name: str, + model, + missing_keys, config, ) -> dict[str, list[torch.Tensor]]: self.config = config @@ -298,21 +310,28 @@ def add_tensor(self, target_key: str, source_key: str, source_pattern: str, futu class WeightRenaming(WeightTransform): # Special case of WeightTransform that only renames keys without any conversion. - def convert(self, layer_name: str, config=None, quantizer=None, missing_keys: Optional[MutableSet[str]] = None): + def convert( + self, + layer_name: str, + model=None, + config=None, + hf_quantizer=None, + missing_keys: Optional[MutableSet[str]] = None, + ): misc = {} for pattern, futures in self.collected_tensors.items(): self.collected_tensors[pattern] = [future.result() for future in futures] collected_tensors = self.collected_tensors - if quantizer is not None and self.quantization_operation is not None: + if hf_quantizer is not None and self.quantization_operation is not None: with log_to_misc(layer_name, misc, (self.collected_tensors, layer_name), self.quantization_operation): collected_tensors = self.quantization_operation.convert( self.collected_tensors, source_keys=self.source_keys, target_keys=self.target_keys, full_layer_name=layer_name, + model=model, config=config, - quant_config=quantizer.quantization_config, missing_keys=missing_keys, ) @@ -332,7 +351,14 @@ def __post_init__(self): if not self.operations: raise ValueError("WeightConverter requires at least one operation.") - def convert(self, layer_name: str, config=None, quantizer=None, missing_keys: Optional[MutableSet[str]] = None): + def convert( + self, + layer_name: str, + model=None, + config=None, + hf_quantizer=None, + missing_keys: Optional[MutableSet[str]] = None, + ): misc = {} for pattern, futures in self.collected_tensors.items(): self.collected_tensors[pattern] = [future.result() for future in futures] @@ -345,9 +371,11 @@ def convert(self, layer_name: str, config=None, quantizer=None, missing_keys: Op source_keys=self.source_keys, target_keys=self.target_keys, full_layer_name=layer_name, + model=model, config=config, + missing_keys=missing_keys, ) - if quantizer is not None and self.quantization_operation is not None: + if hf_quantizer is not None and self.quantization_operation is not None: with log_to_misc(layer_name, misc, (collected_tensors, layer_name), self.quantization_operation): collected_tensors = self.quantization_operation.convert( collected_tensors, @@ -355,7 +383,7 @@ def convert(self, layer_name: str, config=None, quantizer=None, missing_keys: Op target_keys=self.target_keys, full_layer_name=layer_name, config=config, - quant_config=quantizer.quantization_config, + model=model, missing_keys=missing_keys, ) return collected_tensors, misc @@ -626,7 +654,6 @@ def convert_and_load_state_dict_in_model( ``` """ - prefix = model.base_model_prefix tp_plan = tp_plan or {} device_map = device_map or {"": "cpu"} @@ -750,7 +777,11 @@ def convert_and_load_state_dict_in_model( pbar.refresh() try: realized_value, misc = mapping.convert( - first_param_name, config=model.config, quantizer=hf_quantizer, missing_keys=missing_keys + first_param_name, + model=model, + config=model.config, + hf_quantizer=hf_quantizer, + missing_keys=missing_keys, ) for target_name, param in realized_value.items(): param = param[0] if isinstance(param, list) else param diff --git a/src/transformers/integrations/accelerate.py b/src/transformers/integrations/accelerate.py index f201ae3970be..3c1f9fd49c90 100644 --- a/src/transformers/integrations/accelerate.py +++ b/src/transformers/integrations/accelerate.py @@ -241,7 +241,7 @@ def all_tensors(): if name in tied_keys: continue if hf_quantizer is not None: - dtype_size = hf_quantizer.param_element_size(model, name) + dtype_size = hf_quantizer.param_element_size(model, name, param) else: dtype_size = param.element_size() size = param.numel() * dtype_size diff --git a/src/transformers/integrations/bitsandbytes.py b/src/transformers/integrations/bitsandbytes.py index a68b19ff7a1f..1d08e4a7074e 100644 --- a/src/transformers/integrations/bitsandbytes.py +++ b/src/transformers/integrations/bitsandbytes.py @@ -36,7 +36,11 @@ def __init__(self, hf_quantizer): self.hf_quantizer = hf_quantizer def convert( - self, input_dict: torch.Tensor, model: Optional[torch.nn.Module] = None, missing_keys=None, **kwargs + self, + input_dict: dict[str, list[torch.Tensor]], + model: Optional[torch.nn.Module] = None, + missing_keys=None, + **kwargs, ) -> dict[str, torch.Tensor]: """ we need to store some parameters to create the quantized weight. For example, bnb requires 6 values that are stored in the checkpoint to recover the quantized weight. So we store them in a dict that it stored in hf_quantizer for now as we can't save it in the op since we create an op per tensor. @@ -59,6 +63,7 @@ def convert( # remove missing keys that were create when initializing Params4bit for key in new_value.quant_state.as_dict(packed=True).keys(): missing_keys.discard(f"{full_name}.{key}") + module._is_hf_initialized = True return {target_key: new_value} else: module_name = target_key.rsplit(".", 1)[0] @@ -77,6 +82,7 @@ def convert( device=value.device, module=module, ) + module._is_hf_initialized = True del self.hf_quantizer.param_quant_stats[module_name] return {target_key: new_value} return {} @@ -87,7 +93,11 @@ def __init__(self, hf_quantizer): self.hf_quantizer = hf_quantizer def convert( - self, input_dict: torch.Tensor, model: Optional[torch.nn.Module] = None, missing_keys=None, **kwargs + self, + input_dict: dict[str, list[torch.Tensor]], + model: Optional[torch.nn.Module] = None, + missing_keys=None, + **kwargs, ) -> dict[str, torch.Tensor]: target_key, value = tuple(input_dict.items())[0] value = value[0] if isinstance(value, list) else value diff --git a/src/transformers/integrations/tensor_parallel.py b/src/transformers/integrations/tensor_parallel.py index 2ef2ea5467cd..3c1b99651e4e 100644 --- a/src/transformers/integrations/tensor_parallel.py +++ b/src/transformers/integrations/tensor_parallel.py @@ -20,9 +20,13 @@ from functools import partial, reduce from typing import Optional -import torch -import torch.distributed as dist -from torch import nn +from ..utils.import_utils import is_torch_available + + +if is_torch_available(): + import torch + import torch.distributed as dist + from torch import nn from ..distributed import DistributedConfig from ..utils import is_torch_greater_or_equal, logging @@ -31,12 +35,12 @@ logger = logging.get_logger(__name__) -# Cache this result has it's a C FFI call which can be pretty time-consuming -_torch_distributed_available = torch.distributed.is_available() - +if is_torch_available(): + # Cache this result has it's a C FFI call which can be pretty time-consuming + _torch_distributed_available = torch.distributed.is_available() -if is_torch_greater_or_equal("2.5") and _torch_distributed_available: - from torch.distributed.tensor import DTensor, Placement, Replicate, Shard + if is_torch_greater_or_equal("2.5") and _torch_distributed_available: + from torch.distributed.tensor import DTensor, Placement, Replicate, Shard def initialize_tensor_parallelism( @@ -169,19 +173,20 @@ def _get_parameter_tp_plan(parameter_name: str, tp_plan: dict[str, str], is_weig return None -str_to_dtype = { - "BOOL": torch.bool, - "U8": torch.uint8, - "I8": torch.int8, - "I16": torch.int16, - "F16": torch.float16, - "BF16": torch.bfloat16, - "I32": torch.int32, - "F32": torch.float32, - "F64": torch.float64, - "I64": torch.int64, - "F8_E4M3": torch.float8_e4m3fn, -} +if is_torch_available(): + str_to_dtype = { + "BOOL": torch.bool, + "U8": torch.uint8, + "I8": torch.int8, + "I16": torch.int16, + "F16": torch.float16, + "BF16": torch.bfloat16, + "I32": torch.int32, + "F32": torch.float32, + "F64": torch.float64, + "I64": torch.int64, + "F8_E4M3": torch.float8_e4m3fn, + } def get_packed_weights(param, empty_param, device_mesh, rank, dim): diff --git a/src/transformers/integrations/torchao.py b/src/transformers/integrations/torchao.py new file mode 100644 index 000000000000..3a1fdb0d407e --- /dev/null +++ b/src/transformers/integrations/torchao.py @@ -0,0 +1,274 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import importlib.metadata +import re +import types +from typing import Optional + +import torch +from packaging import version + +from transformers.utils import logging +from transformers.utils.import_utils import is_torch_available, is_torchao_available + + +if is_torch_available(): + from ..core_model_loading import ConversionOps +from ..quantizers.quantizers_utils import get_module_from_name + + +if is_torchao_available(): + TORCHAO_VERSION = version.parse(importlib.metadata.version("torchao")) + if version.parse(importlib.metadata.version("torchao")) >= version.parse("0.14.0"): + from torchao.prototype.safetensors.safetensors_support import ( + unflatten_tensor_state_dict, + ) + from torchao.prototype.safetensors.safetensors_utils import is_metadata_torchao + +logger = logging.get_logger(__name__) + + +def fuzzy_match_size(config_name: str) -> Optional[str]: + """ + Extract the size digit from strings like "4weight", "8weight". + Returns the digit as an integer if found, otherwise None. + """ + config_name = config_name.lower() + + str_match = re.search(r"(\d)weight", config_name) + + if str_match: + return str_match.group(1) + + return None + + +def _quantization_type(weight): + from torchao.dtypes import AffineQuantizedTensor + from torchao.quantization.linear_activation_quantized_tensor import LinearActivationQuantizedTensor + + if isinstance(weight, AffineQuantizedTensor): + return f"{weight.__class__.__name__}({weight._quantization_type()})" + + if isinstance(weight, LinearActivationQuantizedTensor): + return f"{weight.__class__.__name__}(activation={weight.input_quant_func}, weight={_quantization_type(weight.original_weight_tensor)})" + + +def _linear_extra_repr(self): + weight = _quantization_type(self.weight) + if weight is None: + return f"in_features={self.weight.shape[1]}, out_features={self.weight.shape[0]}, weight=None" + else: + return f"in_features={self.weight.shape[1]}, out_features={self.weight.shape[0]}, weight={weight}" + + +class TorchAoQuantize(ConversionOps): + def __init__(self, hf_quantizer): + self.hf_quantizer = hf_quantizer + + def convert( + self, + input_dict: dict[str, torch.Tensor], + model: Optional[torch.nn.Module] = None, + full_layer_name: str | None = None, + missing_keys=None, + **kwargs, + ) -> dict[str, torch.Tensor]: + from torchao.quantization import quantize_ + + _, value = tuple(input_dict.items())[0] + value = value[0] if isinstance(value, list) else value + + module, tensor_name = get_module_from_name(model, full_layer_name) + + module._parameters[tensor_name] = torch.nn.Parameter(value, requires_grad=value.requires_grad).to(value.device) + # if we are quantizing tied parameters, to avoid tying the quantized weights + # the correct order to do it is + # 1. load the weight to model + # 2. run tie_weights to populate the weights + # 3. quantize + input_embed = model.get_input_embeddings() + is_embedding_param = id(module) == id(input_embed) + untie_embedding_weights = self.hf_quantizer.quantization_config.untie_embedding_weights + + if untie_embedding_weights and is_embedding_param: + setattr(model.config.get_text_config(decoder=True), "tie_word_embeddings", False) + + # handle FqnToConfig, introduced in torchao 0.15.0+ + if self.hf_quantizer.quantization_config._get_ao_version() >= version.Version("0.15.0"): + from torchao.quantization import FqnToConfig + + config = self.hf_quantizer.quantization_config.get_apply_tensor_subclass() + if isinstance(config, FqnToConfig): + module_fqn, top_level_param_name = full_layer_name.rsplit(".", 1) + c = None + if full_layer_name in config.fqn_to_config: + assert not module_fqn.startswith("re:"), ( + "param fqn should not start with`re:`, which is used for specifying regex" + ) + c = config.module_fqn_to_config[full_layer_name] + elif module_fqn in config.fqn_to_config: + assert not module_fqn.startswith("re:"), ( + "module fqn should not start with`re:`, which is used for specifying regex" + ) + c = config.module_fqn_to_config[module_fqn] + # regex match module and param + else: + for maybe_module_fqn_pattern in config.fqn_to_config: + # if key doesn't start with re, it is an exact fqn key, so we don't regex match + if not maybe_module_fqn_pattern.startswith("re:"): + continue + # see if param matches first + elif re.fullmatch(maybe_module_fqn_pattern[3:], full_layer_name): + c = config.module_fqn_to_config[maybe_module_fqn_pattern] + break + elif re.fullmatch(maybe_module_fqn_pattern[3:], module_fqn): + # we'll apply the config for first fully matched pattern + c = config.module_fqn_to_config[maybe_module_fqn_pattern] + break + else: + c = config.module_fqn_to_config.get("_default", None) + + if c is not None: + if top_level_param_name == "weight": + if is_embedding_param and untie_embedding_weights: + lm_head = module.weight.clone() + # we can apply the module config directly + quantize_(module, c, (lambda x, fqn: True)) + missing_keys.discard(full_layer_name) + module._is_hf_initialized = True + return {"lm_head.weight": lm_head} if is_embedding_param and untie_embedding_weights else {} + else: + # need to apply to custom param name + custom_param_fqn_config = FqnToConfig({top_level_param_name: c}) + quantize_(module, custom_param_fqn_config, filter_fn=None) + missing_keys.discard(full_layer_name) + module._is_hf_initialized = True + return {} + return {full_layer_name: value} + + # handle ModuleFqnToConfig, introduced in torchao 0.12.0+ + # TODO deprecate this when we deprecate ModuleFqnToConfig + elif self.hf_quantizer.quantization_config._get_ao_version() >= version.Version("0.12.0"): + from torchao.quantization import ModuleFqnToConfig + + config = self.hf_quantizer.quantization_config.get_apply_tensor_subclass() + if isinstance(config, ModuleFqnToConfig): + module_fqn, _ = full_layer_name.rsplit(".", 1) + c = None + if module_fqn in config.module_fqn_to_config: + assert not module_fqn.startswith("re:"), ( + "module fqn should not start with`re:`, which is used for specifying regex" + ) + c = config.module_fqn_to_config[module_fqn] + else: + for maybe_module_fqn_pattern in config.module_fqn_to_config: + if not maybe_module_fqn_pattern.startswith("re:"): + continue + elif re.fullmatch(maybe_module_fqn_pattern[3:], module_fqn): + # we'll apply the config for first fully matched pattern + c = config.module_fqn_to_config[maybe_module_fqn_pattern] + break + else: + c = config.module_fqn_to_config.get("_default", None) + if c is not None: + # filter_fn: not filtering out any modules + if is_embedding_param and untie_embedding_weights: + lm_head = module.weight.clone() + quantize_(module, c, filter_fn=lambda x, fqn: True) + missing_keys.discard(full_layer_name) + module._is_hf_initialized = True + return {"lm_head.weight": lm_head} if is_embedding_param and untie_embedding_weights else {} + + return {full_layer_name: value} + + if is_embedding_param and untie_embedding_weights: + lm_head = module.weight.clone() + quantize_(module, self.hf_quantizer.quantization_config.get_apply_tensor_subclass()) + missing_keys.discard(full_layer_name) + module._is_hf_initialized = True + return {"lm_head.weight": lm_head} if is_embedding_param and untie_embedding_weights else {} + + +class TorchAoDeserialize(ConversionOps): + def __init__(self, hf_quantizer): + self.hf_quantizer = hf_quantizer + + def convert( + self, + input_dict: dict[str, torch.Tensor], + model: Optional[torch.nn.Module] = None, + full_layer_name: str | None = None, + missing_keys=None, + **kwargs, + ) -> dict[str, torch.Tensor]: + if isinstance(self.hf_quantizer.quantization_config.quant_type, str): + is_int_4 = "int4" in self.hf_quantizer.quantization_config.quant_type + else: + config_name = self.hf_quantizer.quantization_config.quant_type.__class__.__name__ + is_int_4 = fuzzy_match_size(config_name) == "4" + + # Simple case if we gather layermsnorm weights, we can just return the value since they are not quantized + if "weight:_data" in input_dict.keys(): + value = ( + input_dict["weight:_data"][0] + if isinstance(input_dict["weight:_data"], list) + else input_dict["weight:_data"] + ) + return {full_layer_name: value} + + is_unsafe_serialization = ":" not in list(input_dict.keys())[0] + + param_data = {} + if is_unsafe_serialization: + if isinstance(input_dict["weight"], list): + weight = input_dict["weight"][0] + else: + weight = input_dict["weight"] + else: + if isinstance(input_dict["weight:qdata"], list): + param_data[f"{full_layer_name}:qdata"] = input_dict["weight:qdata"][0] + else: + param_data[f"{full_layer_name}:qdata"] = input_dict["weight:qdata"] + + if isinstance(input_dict["weight:scale"], list): + param_data[f"{full_layer_name}:scale"] = input_dict["weight:scale"][0] + else: + param_data[f"{full_layer_name}:scale"] = input_dict["weight:scale"] + + if is_int_4: + if isinstance(input_dict["weight:zero_point"], list): + param_data[f"{full_layer_name}:zero_point"] = input_dict["weight:zero_point"][0] + else: + param_data[f"{full_layer_name}:zero_point"] = input_dict["weight:zero_point"] + + # If it's a bias, no need to do anything special (except removing the ":_data" part of the key, but was + # already done) - if it's unsafe-serialized (i.e. not safetensors), not need for anything either + if is_unsafe_serialization: + return {full_layer_name: weight} + # Sanity check for the new serialization format + elif not (TORCHAO_VERSION >= version.parse("0.14.0") and is_metadata_torchao(self.hf_quantizer.metadata)): + # print("metadata", self.hf_quantizer.metadata) + raise ValueError("To use `safetensors` serialization, you should have `torchao>=0.14.0` installed") + + new_param = unflatten_tensor_state_dict(param_data, self.hf_quantizer.metadata)[full_layer_name] + + module, _ = get_module_from_name(model, full_layer_name) + # Add repr to the module + if isinstance(module, torch.nn.Linear): + module.extra_repr = types.MethodType(_linear_extra_repr, module) + + return {full_layer_name: new_param} diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index e1df5de3ae35..748d7af639af 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -3889,6 +3889,8 @@ def from_pretrained( weight_conversions.extend( [WeightRenaming(source_keys=k, target_keys=v) for k, v in key_mapping.items()] ) + if hf_quantizer is not None: + weight_conversions.extend(hf_quantizer.get_weight_conversions()) if gguf_file: if hf_quantizer is not None: diff --git a/src/transformers/quantizers/base.py b/src/transformers/quantizers/base.py index 642a2c68065f..b0a5873da303 100644 --- a/src/transformers/quantizers/base.py +++ b/src/transformers/quantizers/base.py @@ -165,8 +165,9 @@ def adjust_target_dtype(self, dtype: "torch.dtype") -> "torch.dtype": """ return dtype - def param_element_size(self, model: "PreTrainedModel", param_name: str) -> float: + def param_element_size(self, model: "PreTrainedModel", param_name: str, param: "torch.Tensor") -> float: "Return the element size (in bytes) for `param_name`." + if self.param_needs_quantization(model, param_name): from accelerate.utils import CustomDtype @@ -179,7 +180,7 @@ def param_element_size(self, model: "PreTrainedModel", param_name: str) -> float # The value passed is actually not used when the method is overridden if (custom_dtype := self.adjust_target_dtype(torch.float16)) in mapping: return mapping[custom_dtype] - return model.get_parameter_or_buffer(param_name).element_size() + return param.element_size() def update_missing_keys(self, model, missing_keys: list[str], prefix: str) -> list[str]: """ @@ -406,6 +407,9 @@ def get_quantize_ops(self): f"{self.quantization_config.quant_method} is not available yet and will be supported soon." ) + def get_weight_conversions(self): + return [] + class SequentialLlama4TextExperts(ModuleList): """ diff --git a/src/transformers/quantizers/quantizer_torchao.py b/src/transformers/quantizers/quantizer_torchao.py index d186aff620f9..777ae193db0d 100644 --- a/src/transformers/quantizers/quantizer_torchao.py +++ b/src/transformers/quantizers/quantizer_torchao.py @@ -31,6 +31,10 @@ from ..utils import is_torch_available, is_torchao_available, logging +if is_torch_available(): + from ..core_model_loading import WeightConverter + + if is_torch_available(): import torch import torch.nn as nn @@ -237,6 +241,8 @@ def update_unexpected_keys(self, model, unexpected_keys: list[str]) -> list[str] return [k for k in unexpected_keys if not any(k.endswith(x) for x in self.full_ao_keys)] def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **kwargs) -> bool: + if self.pre_quantized: + return False if self.quantization_config.quant_type == "autoquant": return False @@ -245,30 +251,30 @@ def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, ** return False elif any(param_name.endswith(f":{x}") for x in self.full_ao_keys): return True - else: - # we only quantize the weight of nn.Linear and nn.Embedding - module, tensor_name = get_module_from_name(model, param_name) - _QUANTIZABLE = [torch.nn.Linear] - if self.quantization_config.include_input_output_embeddings: - _QUANTIZABLE.append(torch.nn.Embedding) - - # Handle FqnToConfig, introduced in torchao 0.15.0+ - if self.quantization_config._get_ao_version() >= version.parse("0.15.0"): - from torchao.quantization import FqnToConfig, fqn_matches_fqn_config - - if isinstance(self.quantization_config.quant_type, FqnToConfig): - module_fqn, param_name_fqn = param_name.rsplit(".", 1) - if ( - fqn_matches_fqn_config(module_fqn, self.quantization_config.quant_type) - or fqn_matches_fqn_config(param_name, self.quantization_config.quant_type) - or ( - "_default" in self.quantization_config.quant_type.fqn_to_config - and isinstance(module, tuple(_QUANTIZABLE)) - ) - ): - return True - return isinstance(module, tuple(_QUANTIZABLE)) and tensor_name == "weight" + # we only quantize the weight of nn.Linear and nn.Embedding + module, tensor_name = get_module_from_name(model, param_name) + _QUANTIZABLE = [torch.nn.Linear] + if self.quantization_config.include_input_output_embeddings: + _QUANTIZABLE.append(torch.nn.Embedding) + + # Handle FqnToConfig, introduced in torchao 0.15.0+ + if self.quantization_config._get_ao_version() >= version.parse("0.15.0"): + from torchao.quantization import FqnToConfig, fqn_matches_fqn_config + + if isinstance(self.quantization_config.quant_type, FqnToConfig): + module_fqn, param_name_fqn = param_name.rsplit(".", 1) + if ( + fqn_matches_fqn_config(module_fqn, self.quantization_config.quant_type) + or fqn_matches_fqn_config(param_name, self.quantization_config.quant_type) + or ( + "_default" in self.quantization_config.quant_type.fqn_to_config + and isinstance(module, tuple(_QUANTIZABLE)) + ) + ): + return True + + return isinstance(module, tuple(_QUANTIZABLE)) and tensor_name == "weight" def create_quantized_param( self, @@ -530,3 +536,27 @@ def set_metadata(self, checkpoint_files: list[str]): metadata.update(metadata_) # Save it self.metadata = metadata + + def get_quantize_ops(self): + from ..integrations.torchao import TorchAoQuantize + + return TorchAoQuantize(self) + + def get_weight_conversions(self): + from ..integrations.torchao import TorchAoDeserialize + + if self.pre_quantized: + return [ + WeightConverter( + source_keys=["weight:qdata", "weight:scale", "weight:zero_point"], + target_keys="weight", + operations=[TorchAoDeserialize(self)], + ), + WeightConverter( + source_keys=["weight:_data"], + target_keys="weight", + operations=[TorchAoDeserialize(self)], + ), + # used for unsafe serialization + ] + return [] diff --git a/tests/quantization/finegrained_fp8/test_fp8.py b/tests/quantization/finegrained_fp8/test_fp8.py index a6edcbf54278..b05ee61dcb38 100644 --- a/tests/quantization/finegrained_fp8/test_fp8.py +++ b/tests/quantization/finegrained_fp8/test_fp8.py @@ -93,8 +93,8 @@ class FP8QuantizerTest(unittest.TestCase): "model.layers.13": "cpu", "model.layers.14": "cpu", "model.layers.15": "cpu", - "model.rotary_emb": "disk", - "model.norm": "disk", + "model.rotary_emb": "cpu", + "model.norm": "cpu", "lm_head": 0, } @@ -138,7 +138,7 @@ def test_quantized_model_conversion(self): for module in model.modules(): if isinstance(module, FP8Linear): nb_fp8_linear += 1 - + print(model) self.assertEqual(nb_linears - 1, nb_fp8_linear) with init_empty_weights(): @@ -209,6 +209,7 @@ def test_quantized_model_multi_accelerator(self): quantized_model = AutoModelForCausalLM.from_pretrained( self.model_name, device_map="auto", quantization_config=quantization_config ) + print("hf_device_map", quantized_model.hf_device_map) self.assertTrue(set(quantized_model.hf_device_map.values()) == {0, 1}) output = quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens, do_sample=False) diff --git a/tests/quantization/torchao_integration/test_torchao.py b/tests/quantization/torchao_integration/test_torchao.py index d682cc57a386..06c1239ab177 100644 --- a/tests/quantization/torchao_integration/test_torchao.py +++ b/tests/quantization/torchao_integration/test_torchao.py @@ -217,6 +217,7 @@ def test_int8_dynamic_activation_int8_weight_quant(self): self.model_name, device_map=self.device, quantization_config=quant_config, + torch_dtype=torch.bfloat16, ) tokenizer = AutoTokenizer.from_pretrained(self.model_name) @@ -249,6 +250,7 @@ def test_include_input_output_embeddings(self): self.model_name, device_map=self.device, quantization_config=quant_config, + torch_dtype=torch.bfloat16, ) # making sure embedding is quantized self.assertTrue(isinstance(quantized_model.model.embed_tokens.weight, AffineQuantizedTensor)) @@ -273,6 +275,7 @@ def test_per_module_config_skip(self): self.model_name, device_map=self.device, quantization_config=quant_config, + torch_dtype=torch.bfloat16, ) # making sure `model.layers.0.self_attn.q_proj` is skipped self.assertTrue(not isinstance(quantized_model.model.layers[0].self_attn.q_proj.weight, AffineQuantizedTensor)) @@ -296,6 +299,7 @@ def test_module_fqn_to_config_regex_basic(self): self.model_name, device_map=self.device, quantization_config=quant_config, + torch_dtype=torch.bfloat16, ) # making sure `model.layers.0.self_attn.q_proj` is skipped self.assertTrue(not isinstance(quantized_model.model.layers[0].self_attn.q_proj.weight, AffineQuantizedTensor)) @@ -329,6 +333,7 @@ def test_module_fqn_to_config_regex_fullmatch(self): self.model_name, device_map=self.device, quantization_config=quant_config, + torch_dtype=torch.bfloat16, ) # highest precedence is fully specified module fqn self.assertTrue(isinstance(quantized_model.model.layers[3].self_attn.q_proj.weight, Float8Tensor)) @@ -362,6 +367,7 @@ def test_module_fqn_to_config_regex_precedence(self): self.model_name, device_map=self.device, quantization_config=quant_config, + torch_dtype=torch.bfloat16, ) # highest precedence is fully specified module fqn self.assertTrue(isinstance(quantized_model.model.layers[3].self_attn.q_proj.weight, Float8Tensor)) @@ -396,6 +402,7 @@ def test_fqn_to_config_regex_precedence(self): self.model_name, device_map=self.device, quantization_config=quant_config, + torch_dtype=torch.bfloat16, ) self.assertTrue(isinstance(quantized_model.model.layers[3].self_attn.q_proj.weight, Float8Tensor)) self.assertTrue(not isinstance(quantized_model.model.layers[1].self_attn.q_proj.weight, AffineQuantizedTensor)) @@ -427,6 +434,7 @@ def test_fqn_to_config_param_over_module_regex_precedence(self): self.model_name, device_map=self.device, quantization_config=quant_config, + torch_dtype=torch.bfloat16, ) self.assertTrue(not isinstance(quantized_model.model.layers[1].self_attn.q_proj.weight, AffineQuantizedTensor)) self.assertTrue(isinstance(quantized_model.model.layers[1].self_attn.k_proj.weight, AffineQuantizedTensor)) @@ -457,6 +465,7 @@ def test_fqn_to_config_param_over_module_precedence(self): self.model_name, device_map=self.device, quantization_config=quant_config, + torch_dtype=torch.bfloat16, ) self.assertTrue(not isinstance(quantized_model.model.layers[3].self_attn.q_proj.weight, AffineQuantizedTensor)) self.assertTrue(isinstance(quantized_model.model.layers[3].self_attn.k_proj.weight, AffineQuantizedTensor)) @@ -487,6 +496,7 @@ def test_fqn_to_config_exact_over_regex_precedence(self): self.model_name, device_map=self.device, quantization_config=quant_config, + torch_dtype=torch.bfloat16, ) self.assertTrue(not isinstance(quantized_model.model.layers[3].self_attn.q_proj.weight, AffineQuantizedTensor)) self.assertTrue(isinstance(quantized_model.model.layers[1].self_attn.q_proj.weight, AffineQuantizedTensor)) @@ -576,7 +586,7 @@ def test_int4wo_offload(self): "model.layers.18": 0, "model.layers.19": "cpu", "model.layers.20": "cpu", - "model.layers.21": "disk", + "model.layers.21": "cpu", "model.norm": 0, "model.rotary_emb": 0, "lm_head": 0, @@ -587,7 +597,7 @@ def test_int4wo_offload(self): quantized_model = AutoModelForCausalLM.from_pretrained( self.model_name, - dtype=torch.bfloat16, + torch_dtype=torch.bfloat16, device_map=device_map_offload, quantization_config=quant_config, ) @@ -599,7 +609,7 @@ def test_int4wo_offload(self): EXPECTED_OUTPUTS = Expectations( { ("xpu", 3): "What are we having for dinner?\n\nJessica: (smiling)", - ("cuda", 7): "What are we having for dinner?\n- 2. What is the temperature outside", + ("cuda", 7): "What are we having for dinner?\n- 1. What is the temperature outside", } ) # fmt: on @@ -622,7 +632,7 @@ def test_int4wo_quant_multi_accelerator(self): quant_config = TorchAoConfig(config) quantized_model = AutoModelForCausalLM.from_pretrained( self.model_name, - dtype=torch.bfloat16, + torch_dtype=torch.bfloat16, device_map="auto", quantization_config=quant_config, ) @@ -643,7 +653,7 @@ def test_autoquant(self): quantized_model = AutoModelForCausalLM.from_pretrained( self.model_name, - dtype="auto", + torch_dtype="auto", device_map=self.device, quantization_config=quant_config, ) @@ -712,7 +722,9 @@ def check_serialization_expected_output(self, device, expected_output, safe_seri dtype = torch.bfloat16 if isinstance(self.quant_scheme, Int4WeightOnlyConfig) else "auto" with tempfile.TemporaryDirectory() as tmpdirname: self.quantized_model.save_pretrained(tmpdirname, safe_serialization=safe_serialization) - loaded_quantized_model = AutoModelForCausalLM.from_pretrained(tmpdirname, dtype=dtype, device_map=device) + loaded_quantized_model = AutoModelForCausalLM.from_pretrained( + tmpdirname, dtype=dtype, device_map=device, torch_dtype=dtype, use_safetensors=safe_serialization + ) input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(device) output = loaded_quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens) @@ -729,7 +741,7 @@ class TorchAoSafeSerializationTest(TorchAoSerializationTest): @classmethod def setUpClass(cls): cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_name) - cls.EXPECTED_OUTPUT = "What are we having for dinner?\n- 1. What is the temperature outside" + cls.EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)" # placeholder cls.quant_scheme = torchao.quantization.Float8WeightOnlyConfig()