diff --git a/src/transformers/quantizers/base.py b/src/transformers/quantizers/base.py index ad37056cb315..7b505a9f85f9 100644 --- a/src/transformers/quantizers/base.py +++ b/src/transformers/quantizers/base.py @@ -117,6 +117,7 @@ class HfQuantizer(ABC): def __init__(self, quantization_config: QuantizationConfigMixin, **kwargs): self.quantization_config = quantization_config + self.metadata = {} # -- Handle extra kwargs below -- self.modules_to_not_convert = kwargs.pop("modules_to_not_convert", []) @@ -392,10 +393,6 @@ def get_state_dict_and_metadata(self, model, safe_serialization=False): """Get state dict and metadata. Useful when we need to modify a bit the state dict due to quantization""" return None, {} - def update_state_dict_with_metadata(self, state_dict, metadata): - """Update state dict with metadata. Default behaviour returns state_dict""" - return state_dict - @abstractmethod def is_serializable(self, safe_serialization=None): ... diff --git a/src/transformers/quantizers/quantizer_torchao.py b/src/transformers/quantizers/quantizer_torchao.py index 72364db59a13..9dddfba3e646 100644 --- a/src/transformers/quantizers/quantizer_torchao.py +++ b/src/transformers/quantizers/quantizer_torchao.py @@ -14,7 +14,6 @@ import importlib import re import types -from collections import defaultdict from typing import TYPE_CHECKING, Optional, Union from packaging import version @@ -38,14 +37,14 @@ if is_torchao_available(): import torchao - if version.parse(importlib.metadata.version("torchao")) >= version.parse("0.14.0"): + if version.parse(importlib.metadata.version("torchao")) >= version.parse("0.15.0"): + from torchao.prototype.awq import AWQConfig from torchao.prototype.safetensors.safetensors_support import ( flatten_tensor_state_dict, unflatten_tensor_state_dict, ) from torchao.prototype.safetensors.safetensors_utils import is_metadata_torchao - logger = logging.get_logger(__name__) @@ -87,6 +86,11 @@ def _linear_extra_repr(self): SUPPORTED_SAFE_SERIALIZATION_CONFIGS = [ torchao.quantization.Float8WeightOnlyConfig, torchao.quantization.Float8DynamicActivationFloat8WeightConfig, + torchao.quantization.Int4WeightOnlyConfig, + torchao.quantization.IntxWeightOnlyConfig, + torchao.quantization.Int8DynamicActivationIntxWeightConfig, + torchao.quantization.ModuleFqnToConfig, + AWQConfig, ] TORCHAO_VERSION = version.parse(importlib.metadata.version("torchao")) @@ -104,20 +108,6 @@ class TorchAoHfQuantizer(HfQuantizer): def __init__(self, quantization_config, **kwargs): super().__init__(quantization_config, **kwargs) - if isinstance(self.quantization_config.quant_type, str): - is_int_4 = "int4" in self.quantization_config.quant_type - else: - config_name = self.quantization_config.quant_type.__class__.__name__ - is_int_4 = fuzzy_match_size(config_name) == "4" - - # TODO: better way to get the serialized key names? Hard to read from torchao codebase - if is_int_4: - self.weight_ao_keys = ["qdata", "scale", "zero_point"] - else: - self.weight_ao_keys = ["qdata", "scale"] - # Instead of serializing the simple torch.Tensor like usual, torchao adds a `:_data` suffix so we need this - self.full_ao_keys = self.weight_ao_keys + ["_data"] - def validate_environment(self, *args, **kwargs): if not is_torchao_available(): raise ImportError("Loading an torchao quantized model requires torchao library (`pip install torchao`)") @@ -168,11 +158,11 @@ def get_state_dict_and_metadata(self, model, safe_serialization: Optional[bool] the safetensors format. """ if type(self.quantization_config.quant_type) in SUPPORTED_SAFE_SERIALIZATION_CONFIGS and safe_serialization: - if TORCHAO_VERSION >= version.parse("0.14.0"): + if TORCHAO_VERSION >= version.parse("0.15.0"): return flatten_tensor_state_dict(model.state_dict()) else: raise RuntimeError( - f"In order to use safetensors with torchao, please use torchao version >= 0.14.0. Current version: {TORCHAO_VERSION}" + f"In order to use safetensors with torchao, please use torchao version >= 0.15.0. Current version: {TORCHAO_VERSION}" ) else: return None, {} @@ -234,7 +224,7 @@ def _process_model_before_weight_loading( return def update_unexpected_keys(self, model, unexpected_keys: list[str]) -> list[str]: - return [k for k in unexpected_keys if not any(k.endswith(x) for x in self.full_ao_keys)] + return [k for k in unexpected_keys if "_weight_" not in k] def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **kwargs) -> bool: if self.quantization_config.quant_type == "autoquant": @@ -243,7 +233,7 @@ def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, ** # check if the param_name is not in self.modules_to_not_convert if any(key + "." in param_name or key == param_name for key in self.modules_to_not_convert): return False - elif any(param_name.endswith(f":{x}") for x in self.full_ao_keys): + elif "_weight_" in param_name: return True else: # we only quantize the weight of nn.Linear and nn.Embedding @@ -284,42 +274,12 @@ def create_quantized_param( """ from torchao.quantization import quantize_ - full_name = param_name - # Those are the pre quantized weights - if ":" in param_name: - param_name = param_name.rsplit(":", 1)[0] module, tensor_name = get_module_from_name(model, param_name) - if self.pre_quantized: - # If it's a bias, no need to do anything special (except removing the ":_data" part of the key, but was - # already done) - if it's unsafe-serialized (i.e. not safetensors), not need for anything either - is_unsafe_serialization = ":" not in full_name - if tensor_name == "bias" or is_unsafe_serialization: - module._parameters[tensor_name] = torch.nn.Parameter( - param_value.to(target_device), requires_grad=param_value.requires_grad - ) - return - # Sanity check for the new serialization format - elif not (TORCHAO_VERSION >= version.parse("0.14.0") and is_metadata_torchao(self.metadata)): - raise ValueError("To use `safetensors` serialization, you should have `torchao>=0.14.0` installed") - - # Save the states for later quantization when they are all gathered - if not hasattr(self, "ao_params"): - self.ao_params = defaultdict(dict) - self.ao_params[param_name].update({full_name: param_value}) - - # We are ready for quantization in this case (we retrieved all the needed keys) - if len(self.ao_params[param_name]) == len(self.weight_ao_keys): - new_param = unflatten_tensor_state_dict(self.ao_params[param_name], self.metadata)[param_name] - # Set it - module._parameters[tensor_name] = torch.nn.Parameter( - new_param.to(target_device), requires_grad=new_param.requires_grad - ) - - # Free memory - del self.ao_params[param_name] + module._parameters[tensor_name] = torch.nn.Parameter( + param_value.to(target_device), requires_grad=param_value.requires_grad + ) - # Add repr to the module if isinstance(module, nn.Linear): module.extra_repr = types.MethodType(_linear_extra_repr, module) else: @@ -430,6 +390,32 @@ def preprocess_model(self, model: "PreTrainedModel", config, dtype=None, checkpo def _process_model_after_weight_loading(self, model, **kwargs): """No process required for torchao quantized model""" + if TORCHAO_VERSION >= version.parse("0.15.0") and is_metadata_torchao(self.metadata): + _, updated_state_dict = unflatten_tensor_state_dict(model.state_dict(), self.metadata) + + weights_to_register = set(updated_state_dict.keys()) + + for name, param in list(model.named_parameters()): + module_fqn, weight_name = name.rsplit(".", 1) + module = model.get_submodule(module_fqn) + weight = getattr(module, weight_name) + + device = weight.device + requires_grad = weight.requires_grad + + if "_weight_" in weight_name: + delattr(module, weight_name) + + if name in weights_to_register: + new_param_value = updated_state_dict[name] + new_param = torch.nn.Parameter(new_param_value.to(device), requires_grad=requires_grad) + module.register_parameter(weight_name, new_param) + + weights_to_register.remove(name) + + model.load_state_dict(updated_state_dict, strict=False) + return + if self.quantization_config.quant_type == "autoquant": from torchao import autoquant from torchao.quantization import ALL_AUTOQUANT_CLASS_LIST @@ -448,11 +434,11 @@ def is_serializable(self, safe_serialization=None) -> bool: if safe_serialization: _is_torchao_serializable = type( self.quantization_config.quant_type - ) in SUPPORTED_SAFE_SERIALIZATION_CONFIGS and TORCHAO_VERSION >= version.parse("0.14.0") + ) in SUPPORTED_SAFE_SERIALIZATION_CONFIGS and TORCHAO_VERSION >= version.parse("0.15.0") if not _is_torchao_serializable: logger.warning( f"torchao quantized model only supports safe serialization for {SUPPORTED_SAFE_SERIALIZATION_CONFIGS}, \ - and torchao version >= 0.14.0, please set `safe_serialization` to False for \ + and torchao version >= 0.15.0, please set `safe_serialization` to False for \ {type(self.quantization_config.quant_type)} and {TORCHAO_VERSION}." ) return _is_torchao_serializable diff --git a/tests/quantization/torchao_integration/test_torchao.py b/tests/quantization/torchao_integration/test_torchao.py index d682cc57a386..6a9c4cb4bb5d 100644 --- a/tests/quantization/torchao_integration/test_torchao.py +++ b/tests/quantization/torchao_integration/test_torchao.py @@ -702,7 +702,6 @@ def tearDown(self): def test_original_model_expected_output(self): input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(self.device) output = self.quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens) - self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT) def check_serialization_expected_output(self, device, expected_output, safe_serialization=False): @@ -723,11 +722,12 @@ def test_serialization_expected_output(self): @require_torchao -@require_torchao_version_greater_or_equal("0.14.0") +@require_torchao_version_greater_or_equal("0.15.0") class TorchAoSafeSerializationTest(TorchAoSerializationTest): # called only once for all test in this class @classmethod def setUpClass(cls): + super().setUpClass() cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_name) cls.EXPECTED_OUTPUT = "What are we having for dinner?\n- 1. What is the temperature outside" # placeholder @@ -748,6 +748,16 @@ def tearDown(self): "What are we having for dinner?\n\nJess: (smiling) I", ), (torchao.quantization.Float8WeightOnlyConfig(), "What are we having for dinner?\n\nJessica: (smiling)"), + (Int4WeightOnlyConfig(), "What are we having for dinner?"), + ( + Int4WeightOnlyConfig(int4_packing_format="tile_packed_to_4d"), + "What are we having for dinner?\nRed, white, and green beans,", + ), + ( + torchao.quantization.Int8DynamicActivationIntxWeightConfig(), + "What are we having for dinner?\n\nJessica: (smiling)", + ), + (torchao.quantization.IntxWeightOnlyConfig(), "What are we having for dinner?\n\nJessica: (smiling)"), ] if is_torchao_available() else []