diff --git a/src/transformers/utils/bitsandbytes.py b/src/transformers/utils/bitsandbytes.py
index 94d6c33937a6..95a180dc5f48 100644
--- a/src/transformers/utils/bitsandbytes.py
+++ b/src/transformers/utils/bitsandbytes.py
@@ -265,17 +265,16 @@ def get_keys_to_not_convert(model):
     tied_keys = sum(tied_params, [])
     has_tied_params = len(tied_keys) > 0
 
-    # Check if it is a base model
-    is_base_model = not hasattr(model, model.base_model_prefix)
-
-    # Ignore this for base models (BertModel, GPT2Model, etc.)
-    if (not has_tied_params) and is_base_model:
-        return []
-
-    # otherwise they have an attached head
+    # If there are no tied weights, we want to keep the lm_head (output embedding) in full precision
+    if not has_tied_params:
+        output_emb = model.get_output_embeddings()
+        if output_emb is not None:
+            list_last_module = [name for name, module in model.named_modules() if id(module) == id(output_emb)]
+            return list_last_module
+
+    # otherwise (tied weights, or no output embedding), keep the last module and tied weights in full precision
     list_modules = list(model.named_parameters())
     list_last_module = [list_modules[-1][0]]
-    # add last module together with tied weights
     intersection = set(list_last_module) - set(tied_keys)
     list_untouched = list(set(tied_keys)) + list(intersection)
 
diff --git a/tests/bnb/test_mixed_int8.py b/tests/bnb/test_mixed_int8.py
index f905b26e3f71..3e88a366d82b 100644
--- a/tests/bnb/test_mixed_int8.py
+++ b/tests/bnb/test_mixed_int8.py
@@ -124,6 +124,53 @@ def tearDown(self):
         gc.collect()
         torch.cuda.empty_cache()
 
+    def test_get_keys_to_not_convert(self):
+        r"""
+        Test the `get_keys_to_not_convert` function.
+        """
+        from accelerate import init_empty_weights
+
+        from transformers import AutoModelForMaskedLM, Blip2ForConditionalGeneration, MptForCausalLM, OPTForCausalLM
+        from transformers.utils.bitsandbytes import get_keys_to_not_convert
+
+        model_id = "mosaicml/mpt-7b"
+        config = AutoConfig.from_pretrained(
+            model_id, trust_remote_code=True, revision="72e5f594ce36f9cabfa2a9fd8f58b491eb467ee7"
+        )
+        with init_empty_weights():
+            model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)
+        self.assertEqual(get_keys_to_not_convert(model), ["transformer.wte"])
+        # without trust_remote_code
+        config = AutoConfig.from_pretrained(model_id, revision="72e5f594ce36f9cabfa2a9fd8f58b491eb467ee7")
+        with init_empty_weights():
+            model = MptForCausalLM(config)
+        # The order of the keys does not matter, so we sort them before comparing; same for the other tests below.
+        self.assertEqual(sorted(get_keys_to_not_convert(model)), sorted(["lm_head", "transformer.wte"]))
+
+        model_id = "Salesforce/blip2-opt-2.7b"
+        config = AutoConfig.from_pretrained(model_id, revision="1ef7f63a8f0a144c13fdca8103eb7b4691c74cec")
+        with init_empty_weights():
+            model = Blip2ForConditionalGeneration(config)
+        self.assertEqual(
+            sorted(get_keys_to_not_convert(model)),
+            sorted(["language_model.lm_head", "language_model.model.decoder.embed_tokens"]),
+        )
+
+        model_id = "facebook/opt-350m"
+        config = AutoConfig.from_pretrained(model_id, revision="cb32f77e905cccbca1d970436fb0f5e6b58ee3c5")
+        with init_empty_weights():
+            model = OPTForCausalLM(config)
+        self.assertEqual(sorted(get_keys_to_not_convert(model)), sorted(["lm_head", "model.decoder.embed_tokens"]))
+
+        model_id = "roberta-large"
+        config = AutoConfig.from_pretrained(model_id, revision="716877d372b884cad6d419d828bac6c85b3b18d9")
+        with init_empty_weights():
+            model = AutoModelForMaskedLM.from_config(config)
+        self.assertEqual(
+            sorted(get_keys_to_not_convert(model)),
+            sorted(["roberta.embeddings.word_embeddings", "lm_head", "lm_head.decoder"]),
+        )
+
     def test_quantization_config_json_serialization(self):
         r"""
         A simple test to check if the quantization config is correctly serialized and deserialized
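For reference, a minimal standalone sketch (not part of the diff) of the lookup the new branch relies on: keep every entry of model.named_modules() whose module object is the very same instance returned by model.get_output_embeddings(). The OPT-350m checkpoint below is only an illustrative choice of a small config; that model actually ties its embeddings, so the real function would take the tied-weights path for it.

from accelerate import init_empty_weights
from transformers import AutoConfig, OPTForCausalLM

# Hypothetical sketch of the identity-based lookup used by the new branch.
# init_empty_weights puts parameters on the meta device, so only the config is fetched.
config = AutoConfig.from_pretrained("facebook/opt-350m")
with init_empty_weights():
    model = OPTForCausalLM(config)

output_emb = model.get_output_embeddings()
# Matching on object identity (rather than on an attribute name) yields the fully
# qualified module path from the model root, e.g. "lm_head" here, or a nested path
# such as "language_model.lm_head" for wrapper models.
names = [name for name, module in model.named_modules() if id(module) == id(output_emb)]
print(names)  # expected: ["lm_head"]

Fully qualified paths of this form ("transformer.wte", "language_model.lm_head", ...) are exactly what the new test asserts get_keys_to_not_convert returns for MPT, Blip2, OPT, and RoBERTa.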