-
Notifications
You must be signed in to change notification settings - Fork 33.5k
Add TF port of BLIP #22090
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add TF port of BLIP #22090
Changes from 46 commits
98b2afd
e557c34
87767b0
d86ec34
35deb28
0a720e4
490fc63
6dc06bb
9fd4b76
07f99eb
8cfc37d
1c47a2f
70cfe55
2024f5e
cc1694d
f31e96b
e12e305
2d622f6
ad2c87c
6b781df
dab565b
ee823fc
d6c5869
cf307fa
0289c28
c4a4b62
3a082f8
8e73e08
7d0f73b
4ec371b
561d2f8
076948b
429c25e
3086257
adb0330
fba2385
f6c328e
4d71a05
7239db5
09592b2
d4a6fa6
60f078c
e6a7851
f3062b1
77e365e
6fff45c
34463ea
60b7fb7
2a7f52d
d962ac6
0a43f85
09095d1
dd88c83
d0fd3d4
9efd53c
afd5a9c
41fe5e1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -406,6 +406,7 @@ def unpack_inputs(func): | |
| func (`callable`): | ||
| The callable function of the TensorFlow model. | ||
|
|
||
|
|
||
| Returns: | ||
| A callable that wraps the original `func` with the behavior described above. | ||
| """ | ||
|
|
@@ -1157,6 +1158,65 @@ def _from_config(cls, config, **kwargs): | |
| """ | ||
| return cls(config, **kwargs) | ||
|
|
||
| def invert_attention_mask(self, encoder_attention_mask: tf.Tensor) -> tf.Tensor: | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This does not use the state, so better put this as a function in We should probably cleanup the PyTorch side to do the same.
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done! I didn't touch the PyTorch side yet because that's a bigger refactor that touches several models, but I can do it in another PR after this if you want. |
||
| """ | ||
| Invert an attention mask (e.g., switches 0. and 1.). | ||
|
|
||
| Args: | ||
| encoder_attention_mask (`torch.Tensor`): An attention mask. | ||
|
|
||
| Returns: | ||
| `tf.Tensor`: The inverted attention mask. | ||
| """ | ||
| if not isinstance(encoder_attention_mask, tf.Tensor): | ||
| encoder_attention_mask = tf.convert_to_tensor(encoder_attention_mask) # Catches stray NumPy inputs | ||
| if encoder_attention_mask.shape.rank == 3: | ||
| encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :] | ||
| if encoder_attention_mask.shape.rank == 2: | ||
| encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :] | ||
| # T5 has a mask that can compare sequence ids, we can simulate this here with this transposition | ||
| # Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow | ||
| # /transformer/transformer_layers.py#L270 | ||
| # encoder_extended_attention_mask = (encoder_extended_attention_mask == | ||
| # encoder_extended_attention_mask.transpose(-1, -2)) | ||
| encoder_extended_attention_mask = ( | ||
| tf.cast(1, encoder_attention_mask.dtype) - encoder_extended_attention_mask | ||
| ) * encoder_extended_attention_mask.dtype.min | ||
|
|
||
| return encoder_extended_attention_mask | ||
|
|
||
| def get_head_mask(self, head_mask: Optional[tf.Tensor], num_hidden_layers: int) -> tf.Tensor: | ||
| """ | ||
| Prepare the head mask if needed. | ||
|
|
||
| Args: | ||
| head_mask (`tf.Tensor` with shape `[num_heads]` or `[num_hidden_layers x num_heads]`, *optional*): | ||
| The mask indicating if we should keep the heads or not (1.0 for keep, 0.0 for discard). | ||
| num_hidden_layers (`int`): | ||
| The number of hidden layers in the model. | ||
|
|
||
| Returns: | ||
| `tf.Tensor` with shape `[num_hidden_layers x batch x num_heads x seq_length x seq_length]` or list with | ||
| `[None]` for each layer. | ||
| """ | ||
| if head_mask is not None: | ||
| head_mask = self._convert_head_mask_to_5d(head_mask, num_hidden_layers) | ||
| else: | ||
| head_mask = [None] * num_hidden_layers | ||
|
|
||
| return head_mask | ||
|
|
||
| def _convert_head_mask_to_5d(self, head_mask, num_hidden_layers): | ||
| """-> [num_hidden_layers x batch x num_heads x seq_length x seq_length]""" | ||
| if head_mask.shape.rank == 1: | ||
| head_mask = head_mask[None, None, :, None, None] | ||
| head_mask = tf.repeat(head_mask, repeats=num_hidden_layers, axis=0) | ||
| elif head_mask.shape.rank == 2: | ||
| head_mask = head_mask[:, None, :, None, None] | ||
| assert head_mask.shape.rank == 5, f"head_mask.dim != 5, instead {head_mask.dim()}" | ||
| head_mask = tf.cast(head_mask, tf.float32) # switch to float if need + fp16 compatibility | ||
| return head_mask | ||
|
|
||
| def eager_serving(self, inputs): | ||
| """ | ||
| Method used for serving the model. Intended not to be compiled with a tf.function decorator so that we can use | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -34,6 +34,7 @@ | |
| ("bert", "TFBertModel"), | ||
| ("blenderbot", "TFBlenderbotModel"), | ||
| ("blenderbot-small", "TFBlenderbotSmallModel"), | ||
| ("blip", "TFBlipModel"), | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we are missing a few auto classes -- also missing on the PT side! |
||
| ("camembert", "TFCamembertModel"), | ||
| ("clip", "TFCLIPModel"), | ||
| ("convbert", "TFConvBertModel"), | ||
|
|
@@ -213,6 +214,7 @@ | |
| TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict( | ||
| [ | ||
| # Model for Zero Shot Image Classification mapping | ||
| ("blip", "TFBlipModel"), | ||
| ("clip", "TFCLIPModel"), | ||
| ] | ||
| ) | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.