diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 25cfa411321c..55ab06dcb85a 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -727,11 +727,12 @@ def _load_state_dict_into_meta_model(
     device_map_regex = "|".join([re.escape(k) for k in sorted(device_map.keys(), reverse=True)])
 
     is_quantized = hf_quantizer is not None
-    is_hqq_or_bnb = is_quantized and hf_quantizer.quantization_config.quant_method in {
+    is_hqq_or_bnb_or_ao = is_quantized and hf_quantizer.quantization_config.quant_method in {
         QuantizationMethod.HQQ,
         QuantizationMethod.BITS_AND_BYTES,
+        QuantizationMethod.TORCHAO,
     }
-    is_meta_state_dict = shard_file.endswith(".safetensors") and not is_hqq_or_bnb
+    is_meta_state_dict = shard_file.endswith(".safetensors") and not is_hqq_or_bnb_or_ao
     file_pointer = None
     if is_meta_state_dict:
         file_pointer = safe_open(shard_file, framework="pt", device=tensor_device)
@@ -873,7 +874,7 @@ def load_shard_file(args):
         shard_file,
         state_dict,
         disk_only_shard_files,
-        is_hqq_or_bnb,
+        is_hqq_or_bnb_or_ao,
         is_quantized,
         device_map,
         hf_quantizer,
@@ -899,7 +900,7 @@ def load_shard_file(args):
     map_location = "cpu"
     if (
         shard_file.endswith(".safetensors")
-        and not is_hqq_or_bnb
+        and not is_hqq_or_bnb_or_ao
         and not (is_deepspeed_zero3_enabled() and not is_quantized)
     ):
         map_location = "meta"
@@ -922,6 +923,13 @@ def load_shard_file(args):
     # Fix the key names
     state_dict = {key_renaming_mapping[k]: v for k, v in state_dict.items() if k in key_renaming_mapping}
 
+    metadata = None
+    if shard_file.endswith(".safetensors") and is_safetensors_available():
+        with safe_open(shard_file, framework="pt") as f:
+            metadata = f.metadata()
+
+    if hf_quantizer:
+        state_dict = hf_quantizer.update_state_dict_with_metadata(state_dict, metadata)
 
     error_msgs = []
 
@@ -5277,9 +5285,10 @@ def _load_pretrained_model(
             QuantizationMethod.HQQ,
             QuantizationMethod.QUARK,
         }
-        is_hqq_or_bnb = is_quantized and hf_quantizer.quantization_config.quant_method in {
+        is_hqq_or_bnb_or_ao = is_quantized and hf_quantizer.quantization_config.quant_method in {
             QuantizationMethod.HQQ,
             QuantizationMethod.BITS_AND_BYTES,
+            QuantizationMethod.TORCHAO,
         }
 
         # Get all the keys of the state dicts that we have to initialize the model
@@ -5451,7 +5460,7 @@ def _load_pretrained_model(
                 shard_file,
                 state_dict,
                 disk_only_shard_files,
-                is_hqq_or_bnb,
+                is_hqq_or_bnb_or_ao,
                 is_quantized,
                 device_map,
                 hf_quantizer,
diff --git a/src/transformers/quantizers/base.py b/src/transformers/quantizers/base.py
index 323faa9c17e2..8710e1426a8e 100644
--- a/src/transformers/quantizers/base.py
+++ b/src/transformers/quantizers/base.py
@@ -342,6 +342,10 @@ def get_state_dict_and_metadata(self, model, safe_serialization=False):
         """Get state dict and metadata. Useful when we need to modify a bit the state dict due to quantization"""
         return None, {}
 
+    def update_state_dict_with_metadata(self, state_dict, metadata):
+        """Update state dict with metadata. Default behaviour returns state_dict"""
+        return state_dict
+
     @abstractmethod
     def _process_model_before_weight_loading(self, model, **kwargs): ...
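The two pieces above work in tandem: `load_shard_file` now reads the safetensors header metadata from each shard and passes it, together with the loaded state dict, to the quantizer's new `update_state_dict_with_metadata` hook, which defaults to a no-op in `HfQuantizer`. A minimal sketch of that flow, assuming an existing local `model.safetensors` shard; the `DemoQuantizer` class is illustrative only, not part of the PR:

```python
# Sketch: how the loading path consumes safetensors header metadata.
from safetensors import safe_open
from safetensors.torch import load_file


class DemoQuantizer:
    # Mirrors the new HfQuantizer.update_state_dict_with_metadata default:
    # return the state dict unchanged; subclasses override this to rebuild
    # quantized tensors from the metadata.
    def update_state_dict_with_metadata(self, state_dict, metadata):
        return state_dict


shard_file = "model.safetensors"  # illustrative path
state_dict = load_file(shard_file)

metadata = None
with safe_open(shard_file, framework="pt") as f:
    metadata = f.metadata()  # header metadata dict, or None if absent

hf_quantizer = DemoQuantizer()
state_dict = hf_quantizer.update_state_dict_with_metadata(state_dict, metadata)
```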
diff --git a/src/transformers/quantizers/quantizer_torchao.py b/src/transformers/quantizers/quantizer_torchao.py
index cba023a7d811..344c9e3534ed 100644
--- a/src/transformers/quantizers/quantizer_torchao.py
+++ b/src/transformers/quantizers/quantizer_torchao.py
@@ -35,6 +35,17 @@
     import torch
     import torch.nn as nn
 
+if is_torchao_available():
+    import torchao
+
+    if version.parse(importlib.metadata.version("torchao")) >= version.parse("0.14.0"):
+        from torchao.prototype.safetensors.safetensors_support import (
+            flatten_tensor_state_dict,
+            unflatten_tensor_state_dict,
+        )
+        from torchao.prototype.safetensors.safetensors_utils import is_metadata_torchao
+
+
 logger = logging.get_logger(__name__)
 
@@ -81,6 +92,15 @@ def _linear_extra_repr(self):
         return f"in_features={self.weight.shape[1]}, out_features={self.weight.shape[0]}, weight={weight}"
 
 
+if is_torchao_available():
+    SUPPORTED_SAFE_SERIALIZATION_CONFIGS = [
+        torchao.quantization.Float8WeightOnlyConfig,
+        torchao.quantization.Float8DynamicActivationFloat8WeightConfig,
+    ]
+
+    TORCHAO_VERSION = version.parse(importlib.metadata.version("torchao"))
+
+
 class TorchAoHfQuantizer(HfQuantizer):
     """
     Quantizer for torchao: https://github.com/pytorch/ao/
@@ -137,6 +157,21 @@ def update_dtype(self, dtype):
                 dtype = torch.float32
         return dtype
 
+    def get_state_dict_and_metadata(self, model, safe_serialization: Optional[bool] = False):
+        """
+        If the quantization config supports safe serialization, flatten the tensor subclass state dict so that it is
+        compatible with the safetensors format.
+        """
+        if type(self.quantization_config.quant_type) in SUPPORTED_SAFE_SERIALIZATION_CONFIGS and safe_serialization:
+            if TORCHAO_VERSION >= version.parse("0.14.0"):
+                return flatten_tensor_state_dict(model.state_dict())
+            else:
+                raise RuntimeError(
+                    f"In order to use safetensors with torchao, please use torchao version >= 0.14.0. Current version: {TORCHAO_VERSION}"
+                )
+        else:
+            return super().get_state_dict_and_metadata(model)
+
     def adjust_target_dtype(self, dtype: "torch.dtype") -> "torch.dtype":
         if version.parse(importlib.metadata.version("accelerate")) > version.parse("0.19.0"):
             from accelerate.utils import CustomDtype
@@ -279,6 +314,16 @@ def create_quantized_param(
 
             quantize_(module, self.quantization_config.get_apply_tensor_subclass())
 
+    def update_state_dict_with_metadata(self, state_dict, metadata):
+        """
+        If the metadata contains torchao tensor subclass information, reconstruct the tensor subclass state dict
+        from the provided state_dict and metadata.
+        """
+        if TORCHAO_VERSION >= version.parse("0.14.0") and is_metadata_torchao(metadata):
+            return unflatten_tensor_state_dict(state_dict, metadata)
+        else:
+            return super().update_state_dict_with_metadata(state_dict, metadata)
+
     def _process_model_after_weight_loading(self, model, **kwargs):
         """No process required for torchao quantized model"""
         if self.quantization_config.quant_type == "autoquant":
@@ -297,10 +342,17 @@ def _process_model_after_weight_loading(self, model, **kwargs):
 
     def is_serializable(self, safe_serialization=None) -> bool:
         if safe_serialization:
-            logger.warning(
-                "torchao quantized model does not support safe serialization, please set `safe_serialization` to False"
-            )
-            return False
+            _is_torchao_serializable = type(
+                self.quantization_config.quant_type
+            ) in SUPPORTED_SAFE_SERIALIZATION_CONFIGS and TORCHAO_VERSION >= version.parse("0.14.0")
+            if not _is_torchao_serializable:
+                logger.warning(
+                    f"torchao quantized models only support safe serialization for {SUPPORTED_SAFE_SERIALIZATION_CONFIGS} "
+                    f"with torchao version >= 0.14.0. Please set `safe_serialization` to False for "
+                    f"{type(self.quantization_config.quant_type)} and torchao {TORCHAO_VERSION}."
+                )
+            return _is_torchao_serializable
+
         _is_torchao_serializable = version.parse(importlib.metadata.version("huggingface_hub")) >= version.parse(
             "0.25.0"
         )
diff --git a/tests/quantization/torchao_integration/test_torchao.py b/tests/quantization/torchao_integration/test_torchao.py
index 0ea22ae08df0..1ddc2de0801f 100644
--- a/tests/quantization/torchao_integration/test_torchao.py
+++ b/tests/quantization/torchao_integration/test_torchao.py
@@ -18,6 +18,7 @@
 import unittest
 
 from packaging import version
+from parameterized import parameterized
 
 from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig
 from transformers.testing_utils import (
@@ -37,6 +38,8 @@
     import torch
 
 if is_torchao_available():
+    import torchao
+
     # renamed in torchao 0.7.0, please install the latest torchao
     from torchao.dtypes import (
         AffineQuantizedTensor,
@@ -135,7 +138,7 @@ class TorchAoTest(unittest.TestCase):
     model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
     device = "cpu"
     quant_scheme_kwargs = (
-        {"group_size": 32, "layout": Int4CPULayout()}
+        {"group_size": 32, "layout": Int4CPULayout(), "version": 1}
         if is_torchao_available() and version.parse(importlib.metadata.version("torchao")) >= version.parse("0.8.0")
         else {"group_size": 32}
     )
@@ -225,6 +228,7 @@ def test_include_input_output_embeddings(self):
             weight_dtype=weight_dtype,
             granularity=granularity,
             mapping_type=mapping_type,
+            version=1,
         )
         config = ModuleFqnToConfig(
             {"_default": None, "model.embed_tokens": embedding_config, "lm_head": embedding_config}
@@ -277,7 +281,7 @@ def test_per_module_config_skip(self):
 @require_torch_accelerator
 class TorchAoAcceleratorTest(TorchAoTest):
     device = torch_device
-    quant_scheme_kwargs = {"group_size": 32}
+    quant_scheme_kwargs = {"group_size": 32, "version": 1}
 
     # called only once for all test in this class
     @classmethod
@@ -327,7 +331,7 @@ def test_int4wo_offload(self):
             "lm_head": 0,
         }
 
-        quant_config = TorchAoConfig("int4_weight_only", group_size=32)
+        quant_config = TorchAoConfig("int4_weight_only", **self.quant_scheme_kwargs)
 
         quantized_model = AutoModelForCausalLM.from_pretrained(
             self.model_name,
@@ -399,7 +403,7 @@ def test_autoquant(self):
 
         check_autoquantized(self, quantized_model.model.layers[0].self_attn.v_proj)
 
-        EXPECTED_OUTPUT = "What are we having for dinner?\n\nJane: (sighs)"
+        EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
         output = quantized_model.generate(
             **input_ids, max_new_tokens=self.max_new_tokens, cache_implementation="static"
         )
@@ -414,7 +418,7 @@ class TorchAoSerializationTest(unittest.TestCase):
     model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
     quant_scheme = "int4_weight_only"
     quant_scheme_kwargs = (
-        {"group_size": 32, "layout": Int4CPULayout()}
+        {"group_size": 32, "layout": Int4CPULayout(), "version": 1}
         if is_torchao_available() and version.parse(importlib.metadata.version("torchao")) >= version.parse("0.8.0")
         else {"group_size": 32}
     )
@@ -447,13 +451,13 @@ def test_original_model_expected_output(self):
         self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
 
-    def check_serialization_expected_output(self, device, expected_output):
+    def check_serialization_expected_output(self, device, expected_output, safe_serialization=False):
         """
         Test if we can serialize and load/infer the model again on the same device
         """
         dtype = torch.bfloat16 if self.quant_scheme == "int4_weight_only" else "auto"
         with tempfile.TemporaryDirectory() as tmpdirname:
-            self.quantized_model.save_pretrained(tmpdirname, safe_serialization=False)
+            self.quantized_model.save_pretrained(tmpdirname, safe_serialization=safe_serialization)
             loaded_quantized_model = AutoModelForCausalLM.from_pretrained(tmpdirname, dtype=dtype, device_map=device)
 
             input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(device)
@@ -464,6 +468,48 @@ def test_serialization_expected_output(self):
         self.check_serialization_expected_output(self.device, self.EXPECTED_OUTPUT)
 
 
+@require_torchao
+@require_torchao_version_greater_or_equal("0.14.0")
+class TorchAoSafeSerializationTest(TorchAoSerializationTest):
+    # called only once for all test in this class
+    @classmethod
+    def setUpClass(cls):
+        cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_name)
+        cls.EXPECTED_OUTPUT = "What are we having for dinner?\n- 1. What is the temperature outside"
+
+    def tearDown(self):
+        gc.collect()
+        backend_empty_cache(torch_device)
+        gc.collect()
+        if hasattr(self, "quantized_model"):
+            del self.quantized_model
+        gc.collect()
+
+    test_params = (
+        [
+            (
+                torchao.quantization.Float8DynamicActivationFloat8WeightConfig(),
+                "What are we having for dinner?\n\nJess: (smiling) I",
+            ),
+            (torchao.quantization.Float8WeightOnlyConfig(), "What are we having for dinner?\n\nJessica: (smiling)"),
+        ]
+        if is_torchao_available()
+        else []
+    )
+
+    @parameterized.expand(test_params, skip_on_empty=True)
+    def test_serialization_expected_output(self, config, expected_output):
+        device = "cuda"
+        self.quant_config = TorchAoConfig(config)
+        self.quantized_model = AutoModelForCausalLM.from_pretrained(
+            self.model_name,
+            dtype=torch.bfloat16,
+            device_map=device,
+            quantization_config=self.quant_config,
+        )
+        self.check_serialization_expected_output(device, expected_output, safe_serialization=True)
+
+
 class TorchAoSerializationW8A8CPUTest(TorchAoSerializationTest):
     quant_scheme, quant_scheme_kwargs = "int8_dynamic_activation_int8_weight", {}
 
@@ -500,7 +546,7 @@ def test_serialization_expected_output_on_accelerator(self):
 
 @require_torch_accelerator
 class TorchAoSerializationAcceleratorTest(TorchAoSerializationTest):
-    quant_scheme, quant_scheme_kwargs = "int4_weight_only", {"group_size": 32}
+    quant_scheme, quant_scheme_kwargs = "int4_weight_only", {"group_size": 32, "version": 1}
     device = f"{torch_device}:0"
 
     # called only once for all test in this class
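End to end, the changes enable a round trip like the following usage sketch, modeled on the new `TorchAoSafeSerializationTest`; torchao >= 0.14.0, a CUDA device, and access to the TinyLlama checkpoint used by the tests are assumed:

```python
import tempfile

import torch
from torchao.quantization import Float8WeightOnlyConfig

from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig

model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
quant_config = TorchAoConfig(Float8WeightOnlyConfig())

quantized_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    dtype=torch.bfloat16,
    device_map="cuda",
    quantization_config=quant_config,
)
tokenizer = AutoTokenizer.from_pretrained(model_name)

with tempfile.TemporaryDirectory() as tmpdirname:
    # Before this change, torchao models had to be saved with
    # safe_serialization=False; now the tensor subclasses are flattened into
    # plain tensors plus safetensors header metadata on save.
    quantized_model.save_pretrained(tmpdirname, safe_serialization=True)

    # On load, the torchao metadata is detected and the tensor subclass
    # state dict is rebuilt via update_state_dict_with_metadata.
    loaded = AutoModelForCausalLM.from_pretrained(
        tmpdirname, dtype=torch.bfloat16, device_map="cuda"
    )

inputs = tokenizer("What are we having for dinner?", return_tensors="pt").to("cuda")
output = loaded.generate(**inputs, max_new_tokens=16)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```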