5 changes: 5 additions & 0 deletions docs/source/en/model_doc/clip.md
@@ -184,6 +184,11 @@ The resource should ideally demonstrate something new instead of duplicating an
[[autodoc]] FlaxCLIPTextModel
- __call__

## FlaxCLIPTextModelWithProjection

[[autodoc]] FlaxCLIPTextModelWithProjection
- __call__

## FlaxCLIPVisionModel

[[autodoc]] FlaxCLIPVisionModel
2 changes: 2 additions & 0 deletions src/transformers/__init__.py
@@ -3965,6 +3965,7 @@
"FlaxCLIPPreTrainedModel",
"FlaxCLIPTextModel",
"FlaxCLIPTextPreTrainedModel",
"FlaxCLIPTextModelWithProjection",
"FlaxCLIPVisionModel",
"FlaxCLIPVisionPreTrainedModel",
]
@@ -7388,6 +7389,7 @@
FlaxCLIPModel,
FlaxCLIPPreTrainedModel,
FlaxCLIPTextModel,
FlaxCLIPTextModelWithProjection,
FlaxCLIPTextPreTrainedModel,
FlaxCLIPVisionModel,
FlaxCLIPVisionPreTrainedModel,
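The two mirrored edits above feed transformers' lazy top-level namespace: the class name must be registered in `_import_structure` and imported under `TYPE_CHECKING`, and the same pattern repeats in the per-model `__init__.py` below. A toy sketch of the mechanism (my illustration using PEP 562 module `__getattr__`; the library itself uses its `_LazyModule` helper):

```python
# Toy module-level sketch of lazy exports (illustrative, not the PR's code).
import importlib
from typing import TYPE_CHECKING

# Runtime side: the submodule is imported only when the name is first accessed.
_import_structure = {"modeling_flax_clip": ["FlaxCLIPTextModelWithProjection"]}

if TYPE_CHECKING:
    # Static side: type checkers and IDEs resolve the real symbol eagerly.
    from .modeling_flax_clip import FlaxCLIPTextModelWithProjection  # noqa: F401
else:
    def __getattr__(name):
        for module_name, exported in _import_structure.items():
            if name in exported:
                module = importlib.import_module(f".{module_name}", __name__)
                return getattr(module, name)
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
```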
2 changes: 2 additions & 0 deletions src/transformers/models/clip/__init__.py
@@ -94,6 +94,7 @@
"FlaxCLIPPreTrainedModel",
"FlaxCLIPTextModel",
"FlaxCLIPTextPreTrainedModel",
"FlaxCLIPTextModelWithProjection",
"FlaxCLIPVisionModel",
"FlaxCLIPVisionPreTrainedModel",
]
@@ -167,6 +168,7 @@
FlaxCLIPModel,
FlaxCLIPPreTrainedModel,
FlaxCLIPTextModel,
FlaxCLIPTextModelWithProjection,
FlaxCLIPTextPreTrainedModel,
FlaxCLIPVisionModel,
FlaxCLIPVisionPreTrainedModel,
102 changes: 102 additions & 0 deletions src/transformers/models/clip/modeling_flax_clip.py
@@ -155,6 +155,36 @@
"""


@flax.struct.dataclass
class FlaxCLIPTextModelOutput(ModelOutput):
"""
Base class for text model's outputs that also contains a pooling of the last hidden states.

Args:
text_embeds (`jnp.ndarray` of shape `(batch_size, output_dim)`):
The text embeddings obtained by applying the projection layer to the pooled output of
[`FlaxCLIPTextModel`].
last_hidden_state (`jnp.ndarray` of shape `(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape
`(batch_size, sequence_length, hidden_size)`.

Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.

Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""

text_embeds: jnp.ndarray = None
last_hidden_state: jnp.ndarray = None
hidden_states: Optional[Tuple[jnp.ndarray]] = None
attentions: Optional[Tuple[jnp.ndarray]] = None


@flax.struct.dataclass
class FlaxCLIPOutput(ModelOutput):
"""
@@ -1007,6 +1037,78 @@ class FlaxCLIPTextModel(FlaxCLIPTextPreTrainedModel):
)


class FlaxCLIPTextModelWithProjectionModule(nn.Module):
config: CLIPTextConfig
dtype: jnp.dtype = jnp.float32

def setup(self):
self.text_model = FlaxCLIPTextTransformer(self.config, dtype=self.dtype)
self.text_projection = nn.Dense(self.config.projection_dim, use_bias=False, dtype=self.dtype)

def __call__(
self,
input_ids,
attention_mask,
position_ids,
deterministic: bool = True,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
):
text_outputs = self.text_model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
deterministic=deterministic,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)

pooled_output = text_outputs[1]
text_embeds = self.text_projection(pooled_output)

if not return_dict:
return (text_embeds, text_outputs[0]) + text_outputs[2:]

return FlaxCLIPTextModelOutput(
text_embeds=text_embeds,
last_hidden_state=text_outputs.last_hidden_state,
hidden_states=text_outputs.hidden_states,
attentions=text_outputs.attentions,
)


class FlaxCLIPTextModelWithProjection(FlaxCLIPTextPreTrainedModel):
module_class = FlaxCLIPTextModelWithProjectionModule


FLAX_CLIP_TEXT_MODEL_WITH_PROJECTION_DOCSTRING = """
Returns:

Example:

```python
>>> from transformers import AutoTokenizer, FlaxCLIPTextModelWithProjection

>>> model = FlaxCLIPTextModelWithProjection.from_pretrained("openai/clip-vit-base-patch32")
>>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")

>>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="np")

>>> outputs = model(**inputs)
>>> text_embeds = outputs.text_embeds
```
"""

overwrite_call_docstring(
FlaxCLIPTextModelWithProjection, CLIP_TEXT_INPUTS_DOCSTRING + FLAX_CLIP_TEXT_MODEL_WITH_PROJECTION_DOCSTRING
)
append_replace_return_docstrings(
FlaxCLIPTextModelWithProjection, output_type=FlaxCLIPTextModelOutput, config_class=CLIPTextConfig
)


class FlaxCLIPVisionModule(nn.Module):
config: CLIPVisionConfig
dtype: jnp.dtype = jnp.float32
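A quick way to sanity-check the new head (a sketch, assuming the projection weights load from the full CLIP checkpoint as they do for the PyTorch `CLIPTextModelWithProjection`): its `text_embeds` should match what `FlaxCLIPModel.get_text_features` computes, since both apply the learned text projection to the pooled output.

```python
import numpy as np
from transformers import AutoTokenizer, FlaxCLIPModel, FlaxCLIPTextModelWithProjection

tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")
inputs = tokenizer(["a photo of a cat"], padding=True, return_tensors="np")

text_model = FlaxCLIPTextModelWithProjection.from_pretrained("openai/clip-vit-base-patch32")
full_model = FlaxCLIPModel.from_pretrained("openai/clip-vit-base-patch32")

text_embeds = text_model(**inputs).text_embeds      # (batch_size, projection_dim)
reference = full_model.get_text_features(**inputs)  # same projection, via the full model

np.testing.assert_allclose(np.asarray(text_embeds), np.asarray(reference), atol=1e-5)
```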
7 changes: 7 additions & 0 deletions src/transformers/utils/dummy_flax_objects.py
@@ -562,6 +562,13 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])


class FlaxCLIPTextModelWithProjection(metaclass=DummyObject):
_backends = ["flax"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])


class FlaxCLIPTextPreTrainedModel(metaclass=DummyObject):
_backends = ["flax"]

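For what the dummy buys: in an environment without JAX/Flax installed, `from transformers import FlaxCLIPTextModelWithProjection` still succeeds, and a descriptive error surfaces only when the class is actually used. A hedged sketch of a session in such an environment:

```python
# Hypothetical session in an environment without Flax installed.
from transformers import FlaxCLIPTextModelWithProjection  # works: resolves to the dummy

try:
    FlaxCLIPTextModelWithProjection()
except ImportError as err:
    print(err)  # requires_backends points the user at the Flax installation instructions
```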
9 changes: 7 additions & 2 deletions tests/models/clip/test_modeling_flax_clip.py
@@ -19,7 +19,12 @@
convert_pytorch_state_dict_to_flax,
load_flax_weights_in_pytorch_model,
)
from transformers.models.clip.modeling_flax_clip import FlaxCLIPModel, FlaxCLIPTextModel, FlaxCLIPVisionModel
from transformers.models.clip.modeling_flax_clip import (
FlaxCLIPModel,
FlaxCLIPTextModel,
FlaxCLIPTextModelWithProjection,
FlaxCLIPVisionModel,
)

if is_torch_available():
import torch
@@ -315,7 +320,7 @@ def prepare_config_and_inputs_for_common(self):

@require_flax
class FlaxCLIPTextModelTest(FlaxModelTesterMixin, unittest.TestCase):
all_model_classes = (FlaxCLIPTextModel,) if is_flax_available() else ()
all_model_classes = (FlaxCLIPTextModel, FlaxCLIPTextModelWithProjection) if is_flax_available() else ()

def setUp(self):
self.model_tester = FlaxCLIPTextModelTester(self)
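The one-line test change is enough because the mixin-style tests iterate over `all_model_classes`, so every shared check now also runs against the projection variant. A rough standalone equivalent (illustrative config values, not the tester's exact ones):

```python
import numpy as np
from transformers import CLIPTextConfig, FlaxCLIPTextModel, FlaxCLIPTextModelWithProjection

# Tiny config so random-init forward passes are cheap, as the common tests do.
config = CLIPTextConfig(
    vocab_size=99, hidden_size=32, num_hidden_layers=2, num_attention_heads=4,
    intermediate_size=37, projection_dim=32, max_position_embeddings=16,
)
input_ids = np.arange(8, dtype="i4").reshape(1, 8)

for model_class in (FlaxCLIPTextModel, FlaxCLIPTextModelWithProjection):
    model = model_class(config)             # random initialization
    outputs = model(input_ids=input_ids)
    print(model_class.__name__, outputs[0].shape)
```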
1 change: 1 addition & 0 deletions utils/check_repo.py
@@ -205,6 +205,7 @@
"TFGroupViTTextModel",
"TFGroupViTVisionModel",
"FlaxCLIPTextModel",
"FlaxCLIPTextModelWithProjection",
"FlaxCLIPVisionModel",
"FlaxWav2Vec2ForCTC",
"DetrForSegmentation",