diff --git a/src/transformers/models/clip/modeling_tf_clip.py b/src/transformers/models/clip/modeling_tf_clip.py index 5d2096200152..ad26a7bfc38d 100644 --- a/src/transformers/models/clip/modeling_tf_clip.py +++ b/src/transformers/models/clip/modeling_tf_clip.py @@ -551,11 +551,14 @@ def call( ) def _build_causal_attention_mask(self, batch_size, seq_length, dtype=tf.float32): - - diag = tf.constant(0.0, shape=(seq_length,), dtype=dtype) + # It is possible with an unspecified sequence length for seq_length to be + # a runtime value, which is unsupported by tf.constant. Per the TensorFlow + # docs, tf.fill can handle runtime dynamic shapes: + # https://www.tensorflow.org/api_docs/python/tf/fill + diag = tf.cast(tf.fill((seq_length,), 0.0), dtype) # set an additive 2D attention mask with all places being masked - to_mask = tf.constant(-10000.0, shape=(seq_length, seq_length), dtype=dtype) + to_mask = tf.cast(tf.fill((seq_length, seq_length), -10000.0), dtype) # set diagonal & lower triangular parts to 0 (i.e. the places not to be masked) # TIP: think the 2D matrix as the space of (query_seq, key_seq) @@ -1082,6 +1085,18 @@ def call( return outputs + @tf.function( + input_signature=[ + { + "input_ids": tf.TensorSpec((None, None), tf.int32, name="input_ids"), + "attention_mask": tf.TensorSpec((None, None), tf.int32, name="attention_mask"), + } + ] + ) + def serving(self, inputs: Dict[str, tf.Tensor]) -> TFBaseModelOutputWithPooling: + output = self.call(inputs) + return self.serving_output(output) + def serving_output(self, output: TFBaseModelOutputWithPooling) -> TFBaseModelOutputWithPooling: hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None @@ -1123,7 +1138,7 @@ def dummy_inputs(self) -> Dict[str, tf.Tensor]: } ] ) - def serving(self, inputs): + def serving(self, inputs: Dict[str, tf.Tensor]) -> TFBaseModelOutputWithPooling: """ Method used for serving the model. @@ -1226,7 +1241,7 @@ def dummy_inputs(self) -> Dict[str, tf.Tensor]: } ] ) - def serving(self, inputs): + def serving(self, inputs: Dict[str, tf.Tensor]) -> TFCLIPOutput: """ Method used for serving the model. @@ -1375,4 +1390,7 @@ def call( return outputs def serving_output(self, output: TFCLIPOutput) -> TFCLIPOutput: + # TODO: As is this currently fails with saved_model=True, because + # TensorFlow cannot trace through nested dataclasses. Reference: + # https://github.com/huggingface/transformers/pull/16886 return output diff --git a/tests/models/clip/test_modeling_tf_clip.py b/tests/models/clip/test_modeling_tf_clip.py index ea572e6a2a51..797d5b73b349 100644 --- a/tests/models/clip/test_modeling_tf_clip.py +++ b/tests/models/clip/test_modeling_tf_clip.py @@ -256,6 +256,62 @@ def test_model_from_pretrained(self): model = TFCLIPVisionModel.from_pretrained(model_name) self.assertIsNotNone(model) + @slow + def test_saved_model_creation_extended(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.output_hidden_states = True + config.output_attentions = True + + if hasattr(config, "use_cache"): + config.use_cache = True + + # in CLIP, the seq_len equals the number of patches + 1 (we add 1 for the [CLS] token) + image_size = (self.model_tester.image_size, self.model_tester.image_size) + patch_size = (self.model_tester.patch_size, self.model_tester.patch_size) + num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) + seq_len = num_patches + 1 + + for model_class in self.all_model_classes: + class_inputs_dict = self._prepare_for_class(inputs_dict, model_class) + model = model_class(config) + num_out = len(model(class_inputs_dict)) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname, saved_model=True) + saved_model_dir = os.path.join(tmpdirname, "saved_model", "1") + model = tf.keras.models.load_model(saved_model_dir) + outputs = model(class_inputs_dict) + output_hidden_states = outputs["hidden_states"] + output_attentions = outputs["attentions"] + + # Check num outputs + self.assertEqual(len(outputs), num_out) + + # Check num layers + expected_num_layers = getattr( + self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1 + ) + + self.assertEqual(len(output_hidden_states), expected_num_layers) + self.assertEqual(len(output_attentions), self.model_tester.num_hidden_layers) + + # Check attention outputs + image_size = (self.model_tester.image_size, self.model_tester.image_size) + patch_size = (self.model_tester.patch_size, self.model_tester.patch_size) + num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) + seq_len = num_patches + 1 + + self.assertListEqual( + list(output_attentions[0].shape[-3:]), + [self.model_tester.num_attention_heads, seq_len, seq_len], + ) + + # Check hidden states + self.assertListEqual( + list(output_hidden_states[0].shape[-2:]), + [seq_len, self.model_tester.hidden_size], + ) + class TFCLIPTextModelTester: def __init__( @@ -367,6 +423,54 @@ def test_model_from_pretrained(self): model = TFCLIPTextModel.from_pretrained(model_name) self.assertIsNotNone(model) + @slow + def test_saved_model_creation_extended(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.output_hidden_states = True + config.output_attentions = True + + if hasattr(config, "use_cache"): + config.use_cache = True + + for model_class in self.all_model_classes: + class_inputs_dict = self._prepare_for_class(inputs_dict, model_class) + model = model_class(config) + num_out = len(model(class_inputs_dict)) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname, saved_model=True) + saved_model_dir = os.path.join(tmpdirname, "saved_model", "1") + model = tf.keras.models.load_model(saved_model_dir) + outputs = model(class_inputs_dict) + output_hidden_states = outputs["hidden_states"] + output_attentions = outputs["attentions"] + + # Check number of outputs + self.assertEqual(len(outputs), num_out) + + # Check number of layers + expected_num_layers = getattr( + self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1 + ) + + # Check hidden states + self.assertEqual(len(output_hidden_states), expected_num_layers) + self.assertListEqual( + list(output_hidden_states[0].shape[-2:]), + [self.model_tester.seq_length, self.model_tester.hidden_size], + ) + + # Check attention outputs + self.assertEqual(len(output_attentions), self.model_tester.num_hidden_layers) + + seq_length = self.model_tester.seq_length + key_length = getattr(self.model_tester, "key_length", seq_length) + + self.assertListEqual( + list(output_attentions[0].shape[-3:]), + [self.model_tester.num_attention_heads, seq_length, key_length], + ) + class TFCLIPModelTester: def __init__(self, parent, is_training=True): @@ -502,6 +606,11 @@ def test_model_from_pretrained(self): model = TFCLIPModel.from_pretrained(model_name) self.assertIsNotNone(model) + @unittest.skip(reason="Currently `saved_model` doesn't work with nested outputs.") + @slow + def test_saved_model_creation_extended(self): + pass + # We will verify our results on an image of cute cats def prepare_img():