Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
57 commits
Select commit Hold shift + click to select a range
98b2afd
Initial commit
Rocketknight1 Mar 10, 2023
e557c34
more stash commit
Rocketknight1 Mar 14, 2023
87767b0
Yet another stash commit
Rocketknight1 Mar 20, 2023
d86ec34
yet more stash commit
Rocketknight1 Mar 21, 2023
35deb28
Mostly working except for docs / repo consistency
Rocketknight1 Mar 24, 2023
0a720e4
Stop importing model list from torch file
Rocketknight1 Mar 24, 2023
490fc63
Add TF BLIP models to docs
Rocketknight1 Mar 24, 2023
6dc06bb
Add auto classes
Rocketknight1 Mar 24, 2023
9fd4b76
Move get_text_features and get_image_features
Rocketknight1 Mar 24, 2023
07f99eb
Update src/transformers/models/blip/modeling_tf_blip.py
Rocketknight1 Mar 27, 2023
8cfc37d
Update src/transformers/models/blip/modeling_tf_blip.py
Rocketknight1 Mar 27, 2023
1c47a2f
Update src/transformers/models/blip/modeling_tf_blip.py
Rocketknight1 Mar 27, 2023
70cfe55
Update src/transformers/models/blip/modeling_tf_blip_text.py
Rocketknight1 Mar 27, 2023
2024f5e
Update src/transformers/models/blip/modeling_tf_blip.py
Rocketknight1 Mar 27, 2023
cc1694d
Update src/transformers/models/blip/modeling_tf_blip.py
Rocketknight1 Mar 27, 2023
f31e96b
Update src/transformers/models/blip/modeling_tf_blip.py
Rocketknight1 Mar 27, 2023
e12e305
Update src/transformers/models/blip/modeling_tf_blip.py
Rocketknight1 Mar 27, 2023
2d622f6
Update src/transformers/models/blip/modeling_tf_blip.py
Rocketknight1 Mar 27, 2023
ad2c87c
Update tests/models/blip/test_modeling_tf_blip.py
Rocketknight1 Mar 27, 2023
6b781df
Update tests/models/blip/test_modeling_tf_blip.py
Rocketknight1 Mar 27, 2023
dab565b
Update src/transformers/models/blip/modeling_tf_blip.py
Rocketknight1 Mar 27, 2023
ee823fc
Update src/transformers/models/blip/modeling_tf_blip.py
Rocketknight1 Mar 27, 2023
d6c5869
Update tests/models/blip/test_modeling_tf_blip_text.py
Rocketknight1 Mar 27, 2023
cf307fa
Update src/transformers/models/blip/modeling_tf_blip_text.py
Rocketknight1 Mar 27, 2023
0289c28
Update src/transformers/models/blip/modeling_tf_blip.py
Rocketknight1 Mar 27, 2023
c4a4b62
Use channels_last convolutions in TF (better performance + compatibil…
Rocketknight1 Mar 27, 2023
3a082f8
Remove _shape function
Rocketknight1 Mar 27, 2023
8e73e08
Move multi-line statement to one line in PT + TF
Rocketknight1 Mar 27, 2023
7d0f73b
Specify tf.keras.layers instead of importing from it
Rocketknight1 Mar 27, 2023
4ec371b
Remove test_gradient_checkpointing and empty test_training methods
Rocketknight1 Mar 27, 2023
561d2f8
move some multi-line statements to one line
Rocketknight1 Mar 27, 2023
076948b
Update docstring for generate
Rocketknight1 Mar 27, 2023
429c25e
Remove pruned heads set
Rocketknight1 Mar 27, 2023
3086257
Remove self.seq_len_dim
Rocketknight1 Mar 27, 2023
adb0330
Fixed issues with loss computation, should resolve some tests. Also e…
Rocketknight1 Mar 29, 2023
fba2385
ensure original model follows config in more cases
Rocketknight1 Mar 30, 2023
f6c328e
Skip the same cross-attention tests in the PT tests - didn't realize …
Rocketknight1 Mar 30, 2023
4d71a05
Add training args throughout the models and layers
Rocketknight1 Mar 30, 2023
7239db5
make fixup
Rocketknight1 Mar 30, 2023
09592b2
Fix docstring for inputs_embeds
Rocketknight1 Mar 30, 2023
d4a6fa6
Add docstring for is_decoder
Rocketknight1 Mar 30, 2023
60f078c
Add docstrings to text models
Rocketknight1 Mar 30, 2023
e6a7851
Remove redundant computation
Rocketknight1 Mar 30, 2023
f3062b1
Add unpack_inputs / keras_serializable
Rocketknight1 Mar 30, 2023
77e365e
Add modeling_tf_blip to doctests
Rocketknight1 Mar 30, 2023
6fff45c
Add config classes for keras serialization
Rocketknight1 Mar 30, 2023
34463ea
Changes to allow model porting with pt-to-tf
Rocketknight1 Mar 31, 2023
60b7fb7
Quick fix to decoder head and test tweaks
Rocketknight1 Apr 3, 2023
2a7f52d
Revert an issue with masking the embeddings outputs
Rocketknight1 Apr 3, 2023
d962ac6
Allow missing keys in some equivalence tests (for unused layers)
Rocketknight1 Apr 3, 2023
0a43f85
Add tf-pt equivalence tests back in
Rocketknight1 Apr 3, 2023
09095d1
Update src/transformers/models/blip/modeling_tf_blip.py
Rocketknight1 Apr 3, 2023
dd88c83
Update src/transformers/models/blip/modeling_tf_blip_text.py
Rocketknight1 Apr 3, 2023
d0fd3d4
Update src/transformers/models/blip/modeling_tf_blip_text.py
Rocketknight1 Apr 3, 2023
9efd53c
make fixup
Rocketknight1 Apr 3, 2023
afd5a9c
Refactor invert_attention_mask out into tf_utils
Rocketknight1 Apr 3, 2023
41fe5e1
Re-enable cross-tests on the PT side too
Rocketknight1 Apr 4, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/source/en/index.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ Flax), PyTorch, and/or TensorFlow.
| BiT | ❌ | ❌ | ✅ | ❌ | ❌ |
| Blenderbot | ✅ | ✅ | ✅ | ✅ | ✅ |
| BlenderbotSmall | ✅ | ✅ | ✅ | ✅ | ✅ |
| BLIP | ❌ | ❌ | ✅ | | ❌ |
| BLIP | ❌ | ❌ | ✅ | | ❌ |
| BLIP-2 | ❌ | ❌ | ✅ | ❌ | ❌ |
| BLOOM | ❌ | ✅ | ✅ | ❌ | ❌ |
| BridgeTower | ❌ | ❌ | ✅ | ❌ | ❌ |
Expand Down
40 changes: 38 additions & 2 deletions docs/source/en/model_doc/blip.mdx
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
<!--Copyright 2022 The HuggingFace Team. All rights reserved.
<!--Copyright 2023 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
Expand Down Expand Up @@ -93,4 +93,40 @@ The original code can be found [here](https://github.com/salesforce/BLIP).
## BlipForQuestionAnswering

[[autodoc]] BlipForQuestionAnswering
- forward
- forward

## TFBlipModel

[[autodoc]] TFBlipModel
- call
- get_text_features
- get_image_features

## TFBlipTextModel

[[autodoc]] TFBlipTextModel
- call


## TFBlipVisionModel

[[autodoc]] TFBlipVisionModel
- call


## TFBlipForConditionalGeneration

[[autodoc]] TFBlipForConditionalGeneration
- call


## TFBlipForImageTextRetrieval

[[autodoc]] TFBlipForImageTextRetrieval
- call


## TFBlipForQuestionAnswering

[[autodoc]] TFBlipForQuestionAnswering
- call
22 changes: 22 additions & 0 deletions src/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2902,6 +2902,18 @@
_import_structure["models.blenderbot_small"].extend(
["TFBlenderbotSmallForConditionalGeneration", "TFBlenderbotSmallModel", "TFBlenderbotSmallPreTrainedModel"]
)
_import_structure["models.blip"].extend(
[
"TF_BLIP_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFBlipForConditionalGeneration",
"TFBlipForImageTextRetrieval",
"TFBlipForQuestionAnswering",
"TFBlipModel",
"TFBlipPreTrainedModel",
"TFBlipTextModel",
"TFBlipVisionModel",
]
)
_import_structure["models.camembert"].extend(
[
"TF_CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
Expand Down Expand Up @@ -6143,6 +6155,16 @@
TFBlenderbotSmallModel,
TFBlenderbotSmallPreTrainedModel,
)
from .models.blip import (
TF_BLIP_PRETRAINED_MODEL_ARCHIVE_LIST,
TFBlipForConditionalGeneration,
TFBlipForImageTextRetrieval,
TFBlipForQuestionAnswering,
TFBlipModel,
TFBlipPreTrainedModel,
TFBlipTextModel,
TFBlipVisionModel,
)
from .models.camembert import (
TF_CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
TFCamembertForCausalLM,
Expand Down
20 changes: 15 additions & 5 deletions src/transformers/commands/pt_to_tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ def __init__(
self._extra_commit_description = extra_commit_description
self._override_model_class = override_model_class

def get_inputs(self, pt_model, config):
def get_inputs(self, pt_model, tf_dummy_inputs, config):

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The changes here seem unrelated to this PR and would be better in their own PR, no?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fair! I added them because they were needed for the pt-to-tf code to port the BLIP models correctly. If you'd rather I move them to a separate PR though, that's fine!

"""
Returns the right inputs for the model, based on its signature.
"""
Expand Down Expand Up @@ -255,7 +255,11 @@ def _get_audio_input():
tf_input = processor(**processor_inputs, return_tensors="tf")

# Extra input requirements, in addition to the input modality
if config.is_encoder_decoder or (hasattr(pt_model, "encoder") and hasattr(pt_model, "decoder")):
if (
config.is_encoder_decoder
or (hasattr(pt_model, "encoder") and hasattr(pt_model, "decoder"))
or "decoder_input_ids" in tf_dummy_inputs
):
decoder_input_ids = np.asarray([[1], [1]], dtype=int) * (pt_model.config.decoder_start_token_id or 0)
pt_input.update({"decoder_input_ids": torch.tensor(decoder_input_ids)})
tf_input.update({"decoder_input_ids": tf.convert_to_tensor(decoder_input_ids)})
Expand Down Expand Up @@ -306,18 +310,24 @@ def run(self):
except AttributeError:
raise AttributeError(f"The TensorFlow equivalent of {architectures[0]} doesn't exist in transformers.")

# Load models and acquire a basic input compatible with the model.
# Check the TF dummy inputs to see what keys we need in the forward pass
tf_from_pt_model = tf_class.from_config(config)
tf_dummy_inputs = tf_from_pt_model.dummy_inputs

del tf_from_pt_model # Try to keep only one model in memory at a time

# Load the model and get some basic inputs
pt_model = pt_class.from_pretrained(self._local_dir)
pt_model.eval()

pt_input, tf_input = self.get_inputs(pt_model, config)
pt_input, tf_input = self.get_inputs(pt_model, tf_dummy_inputs, config)

with torch.no_grad():
pt_outputs = pt_model(**pt_input, output_hidden_states=True)
del pt_model # will no longer be used, and may have a large memory footprint

tf_from_pt_model = tf_class.from_pretrained(self._local_dir, from_pt=True)
tf_from_pt_outputs = tf_from_pt_model(**tf_input, output_hidden_states=True)
tf_from_pt_outputs = tf_from_pt_model(**tf_input, output_hidden_states=True, training=False)

# Confirms that cross loading PT weights into TF worked.
crossload_differences = self.find_pt_tf_differences(pt_outputs, tf_from_pt_outputs)
Expand Down
33 changes: 33 additions & 0 deletions src/transformers/modeling_tf_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,6 +406,7 @@ def unpack_inputs(func):
func (`callable`):
The callable function of the TensorFlow model.


Returns:
A callable that wraps the original `func` with the behavior described above.
"""
Expand Down Expand Up @@ -1157,6 +1158,38 @@ def _from_config(cls, config, **kwargs):
"""
return cls(config, **kwargs)

def get_head_mask(self, head_mask: Optional[tf.Tensor], num_hidden_layers: int) -> tf.Tensor:
"""
Prepare the head mask if needed.

Args:
head_mask (`tf.Tensor` with shape `[num_heads]` or `[num_hidden_layers x num_heads]`, *optional*):
The mask indicating if we should keep the heads or not (1.0 for keep, 0.0 for discard).
num_hidden_layers (`int`):
The number of hidden layers in the model.

Returns:
`tf.Tensor` with shape `[num_hidden_layers x batch x num_heads x seq_length x seq_length]` or list with
`[None]` for each layer.
"""
if head_mask is not None:
head_mask = self._convert_head_mask_to_5d(head_mask, num_hidden_layers)
else:
head_mask = [None] * num_hidden_layers

return head_mask

def _convert_head_mask_to_5d(self, head_mask, num_hidden_layers):
"""-> [num_hidden_layers x batch x num_heads x seq_length x seq_length]"""
if head_mask.shape.rank == 1:
head_mask = head_mask[None, None, :, None, None]
head_mask = tf.repeat(head_mask, repeats=num_hidden_layers, axis=0)
elif head_mask.shape.rank == 2:
head_mask = head_mask[:, None, :, None, None]
assert head_mask.shape.rank == 5, f"head_mask.dim != 5, instead {head_mask.dim()}"
head_mask = tf.cast(head_mask, tf.float32) # switch to float if need + fp16 compatibility
return head_mask

def eager_serving(self, inputs):
"""
Method used for serving the model. Intended not to be compiled with a tf.function decorator so that we can use
Expand Down
2 changes: 2 additions & 0 deletions src/transformers/models/auto/modeling_tf_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
("bert", "TFBertModel"),
("blenderbot", "TFBlenderbotModel"),
("blenderbot-small", "TFBlenderbotSmallModel"),
("blip", "TFBlipModel"),

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we are missing a few auto classes -- also missing on the PT side!

("camembert", "TFCamembertModel"),
("clip", "TFCLIPModel"),
("convbert", "TFConvBertModel"),
Expand Down Expand Up @@ -213,6 +214,7 @@
TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
[
# Model for Zero Shot Image Classification mapping
("blip", "TFBlipModel"),
("clip", "TFCLIPModel"),
]
)
Expand Down
42 changes: 41 additions & 1 deletion src/transformers/models/blip/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,13 @@
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available
from ...utils import (
OptionalDependencyNotAvailable,
_LazyModule,
is_tf_available,
is_torch_available,
is_vision_available,
)


_import_structure = {
Expand Down Expand Up @@ -52,6 +58,23 @@
"BlipForImageTextRetrieval",
]

try:
if not is_tf_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_tf_blip"] = [
"TF_BLIP_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFBlipModel",
"TFBlipPreTrainedModel",
"TFBlipForConditionalGeneration",
"TFBlipForQuestionAnswering",
"TFBlipVisionModel",
"TFBlipTextModel",
"TFBlipForImageTextRetrieval",
]

if TYPE_CHECKING:
from .configuration_blip import BLIP_PRETRAINED_CONFIG_ARCHIVE_MAP, BlipConfig, BlipTextConfig, BlipVisionConfig
from .processing_blip import BlipProcessor
Expand Down Expand Up @@ -81,6 +104,23 @@
BlipVisionModel,
)

try:
if not is_tf_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_tf_blip import (
TF_BLIP_PRETRAINED_MODEL_ARCHIVE_LIST,
TFBlipForConditionalGeneration,
TFBlipForImageTextRetrieval,
TFBlipForQuestionAnswering,
TFBlipModel,
TFBlipPreTrainedModel,
TFBlipTextModel,
TFBlipVisionModel,
)

else:
import sys

Expand Down
40 changes: 19 additions & 21 deletions src/transformers/models/blip/modeling_blip.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,17 +313,12 @@ def forward(

bsz, tgt_len, embed_dim = hidden_states.size()

mixed_qkv = self.qkv(hidden_states)
mixed_qkv = (
self.qkv(hidden_states)
.reshape(bsz, tgt_len, 3, self.num_heads, embed_dim // self.num_heads)
.permute(2, 0, 3, 1, 4)
)
query_states, key_states, value_states = (
mixed_qkv[0],
mixed_qkv[1],
mixed_qkv[2],
)
query_states, key_states, value_states = mixed_qkv[0], mixed_qkv[1], mixed_qkv[2]

# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2))
Expand Down Expand Up @@ -587,9 +582,7 @@ def forward(
r"""
Args:
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
than the model's internal embedding lookup matrix.
Embedded representation of the inputs. Should be float, not int tokens.
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

Expand Down Expand Up @@ -824,10 +817,7 @@ def get_image_features(
```"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

vision_outputs = self.vision_model(
pixel_values=pixel_values,
return_dict=return_dict,
)
vision_outputs = self.vision_model(pixel_values=pixel_values, return_dict=return_dict)

pooled_output = vision_outputs[1] # pooled_output
image_features = self.visual_projection(pooled_output)
Expand Down Expand Up @@ -993,6 +983,10 @@ def forward(
```"""

return_dict = return_dict if return_dict is not None else self.config.use_return_dict
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)

vision_outputs = self.vision_model(
pixel_values=pixel_values,
Expand Down Expand Up @@ -1037,7 +1031,7 @@ def generate(
Overrides *generate* function to be able to use the model as a conditional generator

Parameters:
pixel_values (*torch.FloatTensor* of shape *(batch_size, image_width, image_height)*:
pixel_values (*torch.FloatTensor* of shape *(batch_size, num_channels, image_height, image_width)*:
Input image to be processed
input_ids (*torch.LongTensor* of shape *(batch_size, sequence_length)*, *optional*):
The sequence used as a prompt for the generation.
Expand Down Expand Up @@ -1066,9 +1060,7 @@ def generate(
"""

batch_size = pixel_values.shape[0]
vision_outputs = self.vision_model(
pixel_values=pixel_values,
)
vision_outputs = self.vision_model(pixel_values=pixel_values)

image_embeds = vision_outputs[0]

Expand Down Expand Up @@ -1198,6 +1190,10 @@ def forward(
)

return_dict = return_dict if return_dict is not None else self.config.use_return_dict
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)

vision_outputs = self.vision_model(
pixel_values=pixel_values,
Expand Down Expand Up @@ -1266,7 +1262,7 @@ def generate(
Parameters:
input_ids (*torch.LongTensor* of shape *(batch_size, sequence_length)*):
The sequence used as a prompt for the generation.
pixel_values (*torch.FloatTensor* of shape *(batch_size, image_width, image_height)*:
pixel_values (*torch.FloatTensor* of shape *(batch_size, num_channels, image_height, image_width)*:
Input image to be processed
attention_mask (*torch.LongTensor* of shape *(batch_size, sequence_length)*, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`. `1` for
Expand Down Expand Up @@ -1295,9 +1291,7 @@ def generate(
2
```
"""
vision_outputs = self.vision_model(
pixel_values=pixel_values,
)
vision_outputs = self.vision_model(pixel_values=pixel_values)

image_embeds = vision_outputs[0]

Expand Down Expand Up @@ -1412,6 +1406,10 @@ def forward(
```
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)

vision_outputs = self.vision_model(
pixel_values=pixel_values,
Expand Down
Loading