diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index d07ff8618c10..fff47c54764d 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1318,9 +1318,11 @@ def __init__(self, config: PreTrainedConfig, *inputs, **kwargs): self.warnings_issued = {} # Overwrite the class attribute to make it an instance attribute, so models like # `InstructBlipForConditionalGeneration` can dynamically update it without modifying the class attribute - # when a different component (e.g. language_model) is used. + # when a different component (e.g. language_model) is used. Same for `_tied_weights_keys` which pops/adds + # new keys dynamically depending on config values self._keep_in_fp32_modules = copy.copy(self.__class__._keep_in_fp32_modules) self._keep_in_fp32_modules_strict = copy.copy(self.__class__._keep_in_fp32_modules_strict) + self._tied_weights_keys = copy.copy(self.__class__._tied_weights_keys) self.dtype_plan = {} if isinstance(self._keep_in_fp32_modules, list): diff --git a/tests/models/deformable_detr/test_modeling_deformable_detr.py b/tests/models/deformable_detr/test_modeling_deformable_detr.py index 004cbc01e3e1..02013f1c8a68 100644 --- a/tests/models/deformable_detr/test_modeling_deformable_detr.py +++ b/tests/models/deformable_detr/test_modeling_deformable_detr.py @@ -68,7 +68,6 @@ def __init__( num_feature_levels=4, encoder_n_points=2, decoder_n_points=6, - tie_word_embeddings=False, ): self.parent = parent self.batch_size = batch_size @@ -89,7 +88,6 @@ def __init__( self.num_feature_levels = num_feature_levels self.encoder_n_points = encoder_n_points self.decoder_n_points = decoder_n_points - self.tie_word_embeddings = tie_word_embeddings # we also set the expected seq length for both encoder and decoder self.encoder_seq_length = ( @@ -151,9 +149,6 @@ def get_config(self): backbone=None, backbone_config=resnet_config, use_pretrained_backbone=False, - # FIXME; cls attr `toed_weihgt_keys` must not be modified in __init__ - # Several models affected so for now just let it be and fix in separate PR - tie_word_embeddings=self.tie_word_embeddings, ) def prepare_config_and_inputs_for_common(self): @@ -248,6 +243,25 @@ def test_deformable_detr_object_detection_head_model(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_deformable_detr_object_detection_head_model(*config_and_inputs) + def test_tie_weights_is_not_modified(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.tie_word_embeddings = True + + config.with_box_refine = True + config.two_stage = True + + model = DeformableDetrForObjectDetection(config) + self.assertTrue("model.decoder.bbox_embed" in model._tied_weights_keys) + self.assertTrue("model.decoder.class_embed" in model._tied_weights_keys) + + # if we update config attr, model's tied weights keys also change + config.with_box_refine = False + config.two_stage = False + + model = DeformableDetrForObjectDetection(config) + self.assertFalse("model.decoder.bbox_embed" in model._tied_weights_keys) + self.assertFalse("model.decoder.class_embed" in model._tied_weights_keys) + @unittest.skip(reason="Deformable DETR does not use inputs_embeds") def test_inputs_embeds(self): pass diff --git a/tests/models/grounding_dino/test_modeling_grounding_dino.py b/tests/models/grounding_dino/test_modeling_grounding_dino.py index 3896294719da..7002676fe4f7 100644 --- a/tests/models/grounding_dino/test_modeling_grounding_dino.py +++ b/tests/models/grounding_dino/test_modeling_grounding_dino.py @@ -322,6 +322,19 @@ def test_feed_forward_chunking(self): def test_load_save_without_tied_weights(self): pass + def test_tie_weights_is_not_modified(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.tie_word_embeddings = True + + config.decoder_bbox_embed_share = False + model = GroundingDinoForObjectDetection(config) + self.assertFalse(r"bbox_embed.(?![0])\d+" in model._tied_weights_keys) + + # if we update config attr, model's tied weights keys also change + config.decoder_bbox_embed_share = True + model = GroundingDinoForObjectDetection(config) + self.assertTrue(r"bbox_embed.(?![0])\d+" in model._tied_weights_keys) + def test_attention_outputs(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() config.return_dict = True diff --git a/tests/models/mm_grounding_dino/test_modeling_mm_grounding_dino.py b/tests/models/mm_grounding_dino/test_modeling_mm_grounding_dino.py index 3ae6cc9c834a..a02ca88249df 100644 --- a/tests/models/mm_grounding_dino/test_modeling_mm_grounding_dino.py +++ b/tests/models/mm_grounding_dino/test_modeling_mm_grounding_dino.py @@ -327,6 +327,11 @@ def test_feed_forward_chunking(self): def test_load_save_without_tied_weights(self): pass + # Ignore copy + def test_tie_weights_is_not_modified(self): + # this model doesn't need a test + pass + def test_attention_outputs(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() config.return_dict = True