diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index fd96aacabc2a..6c7d91fd641c 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -561,6 +561,7 @@ def _load_state_dict_into_meta_model(
     dtype=None,
     load_in_8bit=False,
     is_safetensors=False,
+    keep_in_fp32_modules=None,
 ):
     """
     This is somewhat similar to `_load_state_dict_into_model`, but deals with a model that has some or all of its
@@ -611,7 +612,14 @@ def _load_state_dict_into_meta_model(
         # We convert floating dtypes to the `dtype` passed. We want to keep the buffers/params
         # in int/uint/bool and not cast them.
         if dtype is not None and torch.is_floating_point(param):
-            param = param.to(dtype)
+            if (
+                keep_in_fp32_modules is not None
+                and any(module_to_keep_in_fp32 in param_name for module_to_keep_in_fp32 in keep_in_fp32_modules)
+                and dtype == torch.float16
+            ):
+                param = param.to(torch.float32)
+            else:
+                param = param.to(dtype)
 
         # For compatibility with PyTorch load_state_dict which converts state dict dtype to existing dtype in model
         if dtype is None:
@@ -964,6 +972,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
     main_input_name = "input_ids"
     _auto_class = None
     _no_split_modules = None
+    _keep_in_fp32_modules = None
 
     # a list of `re` patterns of `state_dict` keys that should be removed from the list of missing
     # keys we find (keys inside the model but not in the checkpoint) and avoid unnecessary warnings.
@@ -2061,6 +2070,10 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
         # Load model
         loading_info = None
 
+        # Keep in fp32 modules
+        keep_in_fp32_modules = None
+        use_keep_in_fp32_modules = False
+
         if pretrained_model_name_or_path is not None:
             pretrained_model_name_or_path = str(pretrained_model_name_or_path)
             is_local = os.path.isdir(pretrained_model_name_or_path)
@@ -2259,6 +2272,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
             # weights entry that is of a floating type - we assume all floating dtype weights are of the same dtype
             # we also may have config.torch_dtype available, but we won't rely on it till v5
             dtype_orig = None
+
             if torch_dtype is not None:
                 if isinstance(torch_dtype, str):
                     if torch_dtype == "auto":
@@ -2276,11 +2290,25 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                     )
                 dtype_orig = cls._set_default_torch_dtype(torch_dtype)
 
+            # Check if `_keep_in_fp32_modules` is not None
+            use_keep_in_fp32_modules = (
+                (cls._keep_in_fp32_modules is not None) and is_accelerate_available() and torch_dtype == torch.float16
+            )
+            if (
+                (cls._keep_in_fp32_modules is not None)
+                and not is_accelerate_available()
+                and torch_dtype == torch.float16
+            ):
+                logger.warning(
+                    "For stability purposes, it is recommended to have accelerate installed when using this model in"
+                    " torch.float16, please install it with `pip install accelerate`"
+                )
+
             if is_sharded:
                 loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"]
             else:
                 loaded_state_dict_keys = [k for k in state_dict.keys()]
-            if low_cpu_mem_usage:
+            if low_cpu_mem_usage or use_keep_in_fp32_modules:
                 state_dict = None
 
         config.name_or_path = pretrained_model_name_or_path
@@ -2299,6 +2327,13 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
         with ContextManagers(init_contexts):
             model = cls(config, *model_args, **model_kwargs)
 
+        # Check first if we are `from_pt`
+        if use_keep_in_fp32_modules:
+            low_cpu_mem_usage = True
+            keep_in_fp32_modules = model._keep_in_fp32_modules
+        else:
+            keep_in_fp32_modules = []
+
         if load_in_8bit:
             from .utils.bitsandbytes import get_keys_to_not_convert, replace_8bit_linear
 
@@ -2309,6 +2344,12 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                 modules_to_not_convert = get_keys_to_not_convert(model)
             else:
                 modules_to_not_convert = load_in_8bit_skip_modules
+
+            if not isinstance(modules_to_not_convert, list):
+                modules_to_not_convert = [modules_to_not_convert]
+
+            modules_to_not_convert.extend(keep_in_fp32_modules)
+
             model = replace_8bit_linear(
                 model, threshold=load_in_8bit_threshold, modules_to_not_convert=modules_to_not_convert
             )
@@ -2415,6 +2456,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                 offload_state_dict=offload_state_dict,
                 dtype=torch_dtype,
                 load_in_8bit=load_in_8bit,
+                keep_in_fp32_modules=keep_in_fp32_modules,
            )
 
        model.is_loaded_in_8bit = load_in_8bit
@@ -2458,6 +2500,7 @@ def _load_pretrained_model(
        offload_state_dict=None,
        dtype=None,
        load_in_8bit=False,
+        keep_in_fp32_modules=None,
    ):
        is_safetensors = False
        if load_in_8bit:
@@ -2534,11 +2577,23 @@ def _fix_key(key):
                if key.startswith(prefix):
                    key = ".".join(key.split(".")[1:])
                param = model_state_dict[key]
+
+                # upcast in fp32 if any
+                target_dtype = dtype
+                if (
+                    keep_in_fp32_modules is not None
+                    and dtype == torch.float16
+                    and any(module_to_keep_in_fp32 in key for module_to_keep_in_fp32 in keep_in_fp32_modules)
+                ):
+                    target_dtype = torch.float32
+
                if param.device == torch.device("meta"):
                    if not load_in_8bit:
-                        set_module_tensor_to_device(model, key, "cpu", torch.empty(*param.size(), dtype=dtype))
+                        set_module_tensor_to_device(model, key, "cpu", torch.empty(*param.size(), dtype=target_dtype))
                    else:
-                        set_module_8bit_tensor_to_device(model, key, "cpu", torch.empty(*param.size(), dtype=dtype))
+                        set_module_8bit_tensor_to_device(
+                            model, key, "cpu", torch.empty(*param.size(), dtype=target_dtype)
+                        )
 
        # retrieve unintialized modules and initialize before maybe overriding that with the pretrained weights.
        if _fast_init:
@@ -2548,6 +2603,12 @@ def _fix_key(key):
            for module in uninitialized_modules:
                model._init_weights(module)
 
+        # Set some modules to fp32 if any
+        if keep_in_fp32_modules is not None:
+            for name, param in model.named_parameters():
+                if any(module_to_keep_in_fp32 in name for module_to_keep_in_fp32 in keep_in_fp32_modules):
+                    param = param.to(torch.float32)
+
        # Make sure we are able to load base models as well as derived models (with heads)
        start_prefix = ""
        model_to_load = model
@@ -2681,6 +2742,7 @@ def _find_mismatched_keys(
                        dtype=dtype,
                        load_in_8bit=load_in_8bit,
                        is_safetensors=is_safetensors,
+                        keep_in_fp32_modules=keep_in_fp32_modules,
                    )
                    error_msgs += new_error_msgs
                else:
diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py
index a1510af74547..96a3dd3c13fd 100644
--- a/src/transformers/models/t5/modeling_t5.py
+++ b/src/transformers/models/t5/modeling_t5.py
@@ -757,6 +757,7 @@ class T5PreTrainedModel(PreTrainedModel):
    is_parallelizable = True
    supports_gradient_checkpointing = True
    _no_split_modules = ["T5Block"]
+    _keep_in_fp32_modules = ["wo"]
 
    @property
    def dummy_inputs(self):
diff --git a/src/transformers/utils/bitsandbytes.py b/src/transformers/utils/bitsandbytes.py
index b2339efd6269..4e14dbaf77d3 100644
--- a/src/transformers/utils/bitsandbytes.py
+++ b/src/transformers/utils/bitsandbytes.py
@@ -150,7 +150,7 @@ def get_keys_to_not_convert(model):
 
    # Ignore this for base models (BertModel, GPT2Model, etc.)
    if (not has_tied_params) and is_base_model:
-        return ""
+        return []
 
    # otherwise they have an attached head
    list_modules = list(model.named_parameters())
diff --git a/tests/mixed_int8/test_mixed_int8.py b/tests/mixed_int8/test_mixed_int8.py
index 67af67e1d5c9..56ce10638d50 100644
--- a/tests/mixed_int8/test_mixed_int8.py
+++ b/tests/mixed_int8/test_mixed_int8.py
@@ -155,6 +155,13 @@ def test_device_and_dtype_assignment(self):
        # Check this does not throw an error
        _ = self.model_fp16.float()
 
+    def test_fp32_int8_conversion(self):
+        r"""
+        Test whether it is possible to mix both `int8` and `fp32` weights when using `keep_in_fp32_modules` correctly.
+ """ + model = AutoModelForSeq2SeqLM.from_pretrained("t5-small", load_in_8bit=True, device_map="auto") + self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wo.weight.dtype == torch.float32) + class MixedInt8ModelClassesTest(BaseMixedInt8Test): def setUp(self): diff --git a/tests/models/t5/test_modeling_t5.py b/tests/models/t5/test_modeling_t5.py index ab6a039c9027..fe3ce7597bfe 100644 --- a/tests/models/t5/test_modeling_t5.py +++ b/tests/models/t5/test_modeling_t5.py @@ -19,7 +19,14 @@ import unittest from transformers import T5Config, is_torch_available -from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device +from transformers.testing_utils import ( + require_accelerate, + require_sentencepiece, + require_tokenizers, + require_torch, + slow, + torch_device, +) from transformers.utils import cached_property from ...generation.test_utils import GenerationTesterMixin @@ -820,6 +827,50 @@ def use_task_specific_params(model, task): model.config.update(model.config.task_specific_params[task]) +@require_torch +@require_accelerate +@require_tokenizers +@slow +class T5ModelFp16Tests(unittest.TestCase): + def test_fp16_fp32_conversion(self): + r""" + A test to check whether the argument `keep_in_fp32_modules` correctly does its job + """ + # Load without using `accelerate` + model = T5ForConditionalGeneration.from_pretrained("t5-small", torch_dtype=torch.float16) + self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wo.weight.dtype == torch.float32) + self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wi.weight.dtype == torch.float16) + + # Load without in bf16 + model = T5ForConditionalGeneration.from_pretrained("t5-small", torch_dtype=torch.bfloat16) + self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wo.weight.dtype == torch.bfloat16) + self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wi.weight.dtype == torch.bfloat16) + + # Load using `accelerate` in bf16 + model = T5ForConditionalGeneration.from_pretrained("t5-small", torch_dtype=torch.bfloat16, device_map="auto") + self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wo.weight.dtype == torch.bfloat16) + self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wi.weight.dtype == torch.bfloat16) + + # Load using `accelerate` in bf16 + model = T5ForConditionalGeneration.from_pretrained( + "t5-small", torch_dtype=torch.bfloat16, low_cpu_mem_usage=True + ) + self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wo.weight.dtype == torch.bfloat16) + self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wi.weight.dtype == torch.bfloat16) + + # Load without using `accelerate` + model = T5ForConditionalGeneration.from_pretrained( + "t5-small", torch_dtype=torch.float16, low_cpu_mem_usage=True + ) + self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wo.weight.dtype == torch.float32) + self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wi.weight.dtype == torch.float16) + + # Load using `accelerate` + model = T5ForConditionalGeneration.from_pretrained("t5-small", torch_dtype=torch.float16, device_map="auto") + self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wo.weight.dtype == torch.float32) + self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wi.weight.dtype == torch.float16) + + @require_torch @require_sentencepiece @require_tokenizers