Skip to content
28 changes: 23 additions & 5 deletions src/transformers/models/clip/modeling_tf_clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -551,11 +551,14 @@ def call(
)

def _build_causal_attention_mask(self, batch_size, seq_length, dtype=tf.float32):

diag = tf.constant(0.0, shape=(seq_length,), dtype=dtype)
# It is possible with an unspecified sequence length for seq_length to be
# a runtime value, which is unsupported by tf.constant. Per the TensorFlow
# docs, tf.fill can handle runtime dynamic shapes:
# https://www.tensorflow.org/api_docs/python/tf/fill
diag = tf.cast(tf.fill((seq_length,), 0.0), dtype)

# set an additive 2D attention mask with all places being masked
to_mask = tf.constant(-10000.0, shape=(seq_length, seq_length), dtype=dtype)
to_mask = tf.cast(tf.fill((seq_length, seq_length), -10000.0), dtype)

# set diagonal & lower triangular parts to 0 (i.e. the places not to be masked)
# TIP: think the 2D matrix as the space of (query_seq, key_seq)
Expand Down Expand Up @@ -1082,6 +1085,18 @@ def call(

return outputs

@tf.function(
input_signature=[
{
"input_ids": tf.TensorSpec((None, None), tf.int32, name="input_ids"),
"attention_mask": tf.TensorSpec((None, None), tf.int32, name="attention_mask"),
}
]
)
def serving(self, inputs: Dict[str, tf.Tensor]) -> TFBaseModelOutputWithPooling:
output = self.call(inputs)
return self.serving_output(output)

def serving_output(self, output: TFBaseModelOutputWithPooling) -> TFBaseModelOutputWithPooling:
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
Expand Down Expand Up @@ -1123,7 +1138,7 @@ def dummy_inputs(self) -> Dict[str, tf.Tensor]:
}
]
)
def serving(self, inputs):
def serving(self, inputs: Dict[str, tf.Tensor]) -> TFBaseModelOutputWithPooling:
"""
Method used for serving the model.

Expand Down Expand Up @@ -1226,7 +1241,7 @@ def dummy_inputs(self) -> Dict[str, tf.Tensor]:
}
]
)
def serving(self, inputs):
def serving(self, inputs: Dict[str, tf.Tensor]) -> TFCLIPOutput:
"""
Method used for serving the model.

Expand Down Expand Up @@ -1375,4 +1390,7 @@ def call(
return outputs

def serving_output(self, output: TFCLIPOutput) -> TFCLIPOutput:
# TODO: As is this currently fails with saved_model=True, because
# TensorFlow cannot trace through nested dataclasses. Reference:
# https://github.com/huggingface/transformers/pull/16886
return output
109 changes: 109 additions & 0 deletions tests/models/clip/test_modeling_tf_clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,62 @@ def test_model_from_pretrained(self):
model = TFCLIPVisionModel.from_pretrained(model_name)
self.assertIsNotNone(model)

@slow
def test_saved_model_creation_extended(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.output_hidden_states = True
config.output_attentions = True

if hasattr(config, "use_cache"):
config.use_cache = True

# in CLIP, the seq_len equals the number of patches + 1 (we add 1 for the [CLS] token)
image_size = (self.model_tester.image_size, self.model_tester.image_size)
patch_size = (self.model_tester.patch_size, self.model_tester.patch_size)
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
seq_len = num_patches + 1

for model_class in self.all_model_classes:
class_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
model = model_class(config)
num_out = len(model(class_inputs_dict))

with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname, saved_model=True)
saved_model_dir = os.path.join(tmpdirname, "saved_model", "1")
model = tf.keras.models.load_model(saved_model_dir)
outputs = model(class_inputs_dict)
output_hidden_states = outputs["hidden_states"]
output_attentions = outputs["attentions"]

# Check num outputs
self.assertEqual(len(outputs), num_out)

# Check num layers
expected_num_layers = getattr(
self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1
)

self.assertEqual(len(output_hidden_states), expected_num_layers)
self.assertEqual(len(output_attentions), self.model_tester.num_hidden_layers)

# Check attention outputs
image_size = (self.model_tester.image_size, self.model_tester.image_size)
patch_size = (self.model_tester.patch_size, self.model_tester.patch_size)
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
seq_len = num_patches + 1

self.assertListEqual(
list(output_attentions[0].shape[-3:]),
[self.model_tester.num_attention_heads, seq_len, seq_len],
)

# Check hidden states
self.assertListEqual(
list(output_hidden_states[0].shape[-2:]),
[seq_len, self.model_tester.hidden_size],
)


class TFCLIPTextModelTester:
def __init__(
Expand Down Expand Up @@ -367,6 +423,54 @@ def test_model_from_pretrained(self):
model = TFCLIPTextModel.from_pretrained(model_name)
self.assertIsNotNone(model)

@slow
def test_saved_model_creation_extended(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.output_hidden_states = True
config.output_attentions = True

if hasattr(config, "use_cache"):
config.use_cache = True

for model_class in self.all_model_classes:
class_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
model = model_class(config)
num_out = len(model(class_inputs_dict))

with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname, saved_model=True)
saved_model_dir = os.path.join(tmpdirname, "saved_model", "1")
model = tf.keras.models.load_model(saved_model_dir)
outputs = model(class_inputs_dict)
output_hidden_states = outputs["hidden_states"]
output_attentions = outputs["attentions"]

# Check number of outputs
self.assertEqual(len(outputs), num_out)

# Check number of layers
expected_num_layers = getattr(
self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1
)

# Check hidden states
self.assertEqual(len(output_hidden_states), expected_num_layers)
self.assertListEqual(
list(output_hidden_states[0].shape[-2:]),
[self.model_tester.seq_length, self.model_tester.hidden_size],
)

# Check attention outputs
self.assertEqual(len(output_attentions), self.model_tester.num_hidden_layers)

seq_length = self.model_tester.seq_length
key_length = getattr(self.model_tester, "key_length", seq_length)

self.assertListEqual(
list(output_attentions[0].shape[-3:]),
[self.model_tester.num_attention_heads, seq_length, key_length],
)


class TFCLIPModelTester:
def __init__(self, parent, is_training=True):
Expand Down Expand Up @@ -502,6 +606,11 @@ def test_model_from_pretrained(self):
model = TFCLIPModel.from_pretrained(model_name)
self.assertIsNotNone(model)

@unittest.skip(reason="Currently `saved_model` doesn't work with nested outputs.")
@slow
def test_saved_model_creation_extended(self):
pass


# We will verify our results on an image of cute cats
def prepare_img():
Expand Down