diff --git a/docs/source/en/index.mdx b/docs/source/en/index.mdx index 5d463e5f7dc6..4154d104410a 100644 --- a/docs/source/en/index.mdx +++ b/docs/source/en/index.mdx @@ -288,7 +288,7 @@ Flax), PyTorch, and/or TensorFlow. | LayoutLMv2 | ✅ | ✅ | ✅ | ❌ | ❌ | | LayoutLMv3 | ✅ | ✅ | ✅ | ✅ | ❌ | | LED | ✅ | ✅ | ✅ | ✅ | ❌ | -| LeViT | ❌ | ❌ | ✅ | ❌ | ❌ | +| LeViT | ❌ | ❌ | ✅ | ✅ | ❌ | | LiLT | ❌ | ❌ | ✅ | ❌ | ❌ | | Longformer | ✅ | ✅ | ✅ | ✅ | ❌ | | LongT5 | ❌ | ❌ | ✅ | ❌ | ✅ | diff --git a/docs/source/en/model_doc/levit.mdx b/docs/source/en/model_doc/levit.mdx index 0a64471b3480..45b6c720f83b 100644 --- a/docs/source/en/model_doc/levit.mdx +++ b/docs/source/en/model_doc/levit.mdx @@ -59,7 +59,8 @@ Tips: - You can check out demo notebooks regarding inference as well as fine-tuning on custom data [here](https://github.com/NielsRogge/Transformers-Tutorials/tree/master/VisionTransformer) (you can just replace [`ViTFeatureExtractor`] by [`LevitImageProcessor`] and [`ViTForImageClassification`] by [`LevitForImageClassification`] or [`LevitForImageClassificationWithTeacher`]). -This model was contributed by [anugunj](https://huggingface.co/anugunj). The original code can be found [here](https://github.com/facebookresearch/LeViT). +This model was contributed by [anugunj](https://huggingface.co/anugunj). The TensorFlow version was contributed by +[Aritra Roy Gosthipaty](https://huggingface.co/ariG23498). The original code can be found [here](https://github.com/facebookresearch/LeViT). ## LevitConfig diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index bcd2e1cebee6..7db87f874a13 100644 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -2793,6 +2793,15 @@ ] ) _import_structure["models.led"].extend(["TFLEDForConditionalGeneration", "TFLEDModel", "TFLEDPreTrainedModel"]) + _import_structure["models.levit"].extend( + [ + "TF_LEVIT_PRETRAINED_MODEL_ARCHIVE_LIST", + "TFLevitForImageClassification", + "TFLevitForImageClassificationWithTeacher", + "TFLevitModel", + "TFLevitPreTrainedModel", + ] + ) _import_structure["models.longformer"].extend( [ "TF_LONGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST", @@ -5646,6 +5655,13 @@ TFLayoutLMv3PreTrainedModel, ) from .models.led import TFLEDForConditionalGeneration, TFLEDModel, TFLEDPreTrainedModel + from .models.levit import ( + TF_LEVIT_PRETRAINED_MODEL_ARCHIVE_LIST, + TFLevitForImageClassification, + TFLevitForImageClassificationWithTeacher, + TFLevitModel, + TFLevitPreTrainedModel, + ) from .models.longformer import ( TF_LONGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST, TFLongformerForMaskedLM, diff --git a/src/transformers/models/auto/modeling_tf_auto.py b/src/transformers/models/auto/modeling_tf_auto.py index 729f83cd9f35..911ff385fbf1 100644 --- a/src/transformers/models/auto/modeling_tf_auto.py +++ b/src/transformers/models/auto/modeling_tf_auto.py @@ -58,6 +58,7 @@ ("layoutlm", "TFLayoutLMModel"), ("layoutlmv3", "TFLayoutLMv3Model"), ("led", "TFLEDModel"), + ("levit", "TFLevitModel"), ("longformer", "TFLongformerModel"), ("lxmert", "TFLxmertModel"), ("marian", "TFMarianModel"), diff --git a/src/transformers/models/levit/__init__.py b/src/transformers/models/levit/__init__.py index f42fb02ad071..7a52103e6d4d 100644 --- a/src/transformers/models/levit/__init__.py +++ b/src/transformers/models/levit/__init__.py @@ -17,7 +17,13 @@ # limitations under the License. 
from typing import TYPE_CHECKING -from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available +from ...utils import ( + OptionalDependencyNotAvailable, + _LazyModule, + is_tf_available, + is_torch_available, + is_vision_available, +) _import_structure = {"configuration_levit": ["LEVIT_PRETRAINED_CONFIG_ARCHIVE_MAP", "LevitConfig", "LevitOnnxConfig"]} @@ -45,6 +51,20 @@ "LevitPreTrainedModel", ] +try: + if not is_tf_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["modeling_tf_levit"] = [ + "TF_LEVIT_PRETRAINED_MODEL_ARCHIVE_LIST", + "TFLevitForImageClassification", + "TFLevitForImageClassificationWithTeacher", + "TFLevitModel", + "TFLevitPreTrainedModel", + ] + if TYPE_CHECKING: from .configuration_levit import LEVIT_PRETRAINED_CONFIG_ARCHIVE_MAP, LevitConfig, LevitOnnxConfig @@ -71,6 +91,20 @@ LevitModel, LevitPreTrainedModel, ) + + try: + if not is_tf_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .modeling_tf_levit import ( + TF_LEVIT_PRETRAINED_MODEL_ARCHIVE_LIST, + TFLevitForImageClassification, + TFLevitForImageClassificationWithTeacher, + TFLevitModel, + TFLevitPreTrainedModel, + ) else: import sys diff --git a/src/transformers/models/levit/modeling_tf_levit.py b/src/transformers/models/levit/modeling_tf_levit.py new file mode 100644 index 000000000000..a66f2cd59436 --- /dev/null +++ b/src/transformers/models/levit/modeling_tf_levit.py @@ -0,0 +1,1048 @@ +# coding=utf-8 +# Copyright 2022 Meta Platforms, Inc. and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" TensorFlow LeViT model.""" + +import itertools +from dataclasses import dataclass +from typing import Dict, Optional, Tuple, Union + +import tensorflow as tf +from tensorflow.keras import backend as K +from tensorflow.keras.losses import BinaryCrossentropy, CategoricalCrossentropy, MeanSquaredError + +from ...modeling_outputs import ModelOutput +from ...modeling_tf_outputs import ( + TFBaseModelOutputWithNoAttention, + TFBaseModelOutputWithPooling, + TFBaseModelOutputWithPoolingAndNoAttention, + TFImageClassifierOutputWithNoAttention, +) +from ...modeling_tf_utils import TFPreTrainedModel, keras_serializable, unpack_inputs +from ...tf_utils import shape_list, stable_softmax +from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging +from .configuration_levit import LevitConfig + + +logger = logging.get_logger(__name__) + +# General docstring +_CONFIG_FOR_DOC = "LevitConfig" +_FEAT_EXTRACTOR_FOR_DOC = "LevitFeatureExtractor" + +# Base docstring +_CHECKPOINT_FOR_DOC = "facebook/levit-128S" +_EXPECTED_OUTPUT_SHAPE = [1, 16, 384] + +# Image classification docstring +_IMAGE_CLASS_CHECKPOINT = "facebook/levit-128S" +_IMAGE_CLASS_EXPECTED_OUTPUT = "tabby, tabby cat" + +TF_LEVIT_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "facebook/levit-128S", + # See all LeViT models at https://huggingface.co/models?filter=levit +] + + +@dataclass +class TFLevitForImageClassificationWithTeacherOutput(ModelOutput): + """ + Output type of [`TFLevitForImageClassificationWithTeacher`]. + + Args: + logits (`tf.Tensor` of shape `(batch_size, config.num_labels)`): + Prediction scores as the average of the `cls_logits` and `distillation_logits`. + cls_logits (`tf.Tensor` of shape `(batch_size, config.num_labels)`): + Prediction scores of the classification head (i.e. the linear layer on top of the final hidden state of the + class token). + distillation_logits (`tf.Tensor` of shape `(batch_size, config.num_labels)`): + Prediction scores of the distillation head (i.e. the linear layer on top of the final hidden state of the + distillation token). + hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape + `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer plus + the initial embedding outputs. + """ + + logits: tf.Tensor = None + cls_logits: tf.Tensor = None + distillation_logits: tf.Tensor = None + hidden_states: Optional[Tuple[tf.Tensor]] = None + + +class TFLevitConvEmbeddings(tf.keras.layers.Layer): + """ + LeViT Conv Embeddings with Batch Norm, used in the initial patch embedding layer. + """ + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation=1, + groups=1, + bn_weight_init=1, + *args, + **kwargs, + ): + super().__init__(*args, **kwargs) + # The padding layer is built in order to pad the inputs before entering the convolution operation. + self.padding = tf.keras.layers.ZeroPadding2D(padding=padding) + self.convolution = tf.keras.layers.Conv2D( + filters=out_channels, + kernel_size=kernel_size, + strides=stride, + dilation_rate=dilation, + groups=groups, + use_bias=False, + data_format="channels_last", + name="convolution", + ) + # The epsilon and momentum used here are the defaults in torch batch norm layer. 
+ self.batch_norm = tf.keras.layers.BatchNormalization(epsilon=1e-05, momentum=0.9, name="batch_norm") + + def call(self, embeddings: tf.Tensor, training: Optional[bool] = None): + # embeddings shape = (bsz, num_channels, height, width) + embeddings = tf.transpose(embeddings, perm=(0, 2, 3, 1)) + embeddings = self.padding(embeddings) + embeddings = self.convolution(embeddings, training=training) + embeddings = self.batch_norm(embeddings, training=training) + # embeddings shape = (bsz, height, width, num_channels) + embeddings = tf.transpose(embeddings, perm=(0, 3, 1, 2)) + return embeddings + + +# Defining hard swish with keras backend. +def hard_swish(x): + return x * (K.relu(x + 3.0, max_value=6.0) / 6.0) + + +class TFLevitPatchEmbeddings(tf.keras.layers.Layer): + """ + LeViT patch embeddings, for final embeddings to be passed to transformer blocks. It consists of multiple + `TFLevitConvEmbeddings`. + """ + + def __init__(self, config, *args, **kwargs): + super().__init__(*args, **kwargs) + self.embedding_layer_1 = TFLevitConvEmbeddings( + in_channels=config.num_channels, + out_channels=config.hidden_sizes[0] // 8, + kernel_size=config.kernel_size, + stride=config.stride, + padding=config.padding, + name="embedding_layer_1", + ) + self.activation_layer_1 = hard_swish + + self.embedding_layer_2 = TFLevitConvEmbeddings( + in_channels=config.hidden_sizes[0] // 8, + out_channels=config.hidden_sizes[0] // 4, + kernel_size=config.kernel_size, + stride=config.stride, + padding=config.padding, + name="embedding_layer_2", + ) + self.activation_layer_2 = hard_swish + + self.embedding_layer_3 = TFLevitConvEmbeddings( + in_channels=config.hidden_sizes[0] // 4, + out_channels=config.hidden_sizes[0] // 2, + kernel_size=config.kernel_size, + stride=config.stride, + padding=config.padding, + name="embedding_layer_3", + ) + self.activation_layer_3 = hard_swish + + self.embedding_layer_4 = TFLevitConvEmbeddings( + in_channels=config.hidden_sizes[0] // 2, + out_channels=config.hidden_sizes[0], + kernel_size=config.kernel_size, + stride=config.stride, + padding=config.padding, + name="embedding_layer_4", + ) + self.num_channels = config.num_channels + + def call(self, pixel_values: tf.Tensor, training: Optional[bool] = None): + batch_size = tf.shape(pixel_values)[0] + num_channels = tf.shape(pixel_values)[1] + + if num_channels != self.num_channels: + raise ValueError( + "Make sure that the channel dimension of the pixel values match with the one set in the configuration." 
+ ) + + embeddings = self.embedding_layer_1(pixel_values, training=training) + embeddings = self.activation_layer_1(embeddings) + embeddings = self.embedding_layer_2(embeddings, training=training) + embeddings = self.activation_layer_2(embeddings) + embeddings = self.embedding_layer_3(embeddings, training=training) + embeddings = self.activation_layer_3(embeddings) + embeddings = self.embedding_layer_4(embeddings, training=training) + + # Flatten the embeddings + num_channels = tf.shape(embeddings)[1] + flattended_embeddings = tf.reshape(embeddings, shape=(batch_size, num_channels, -1)) + # Transpose the channel and spatial axis of the flattened embeddings + transpose_embeddings = tf.transpose(flattended_embeddings, perm=(0, 2, 1)) + return transpose_embeddings + + +class TFMLPLayerWithBN(tf.keras.layers.Layer): + def __init__(self, input_dim, output_dim, bn_weight_init=1, *args, **kwargs): + super().__init__(*args, **kwargs) + self.linear = tf.keras.layers.Dense( + units=output_dim, + use_bias=False, + name="linear", + ) + # The epsilon and momentum used here are the defaults in torch batch norm layer. + self.batch_norm = tf.keras.layers.BatchNormalization( + epsilon=1e-05, + momentum=0.9, + name="batch_norm", + ) + + def call(self, hidden_state: tf.Tensor, training: Optional[bool] = None): + hidden_state = self.linear(hidden_state, training=training) + + # Before sending the hidden state to the batch normalization layer, we would have to + # flatten the hidden states with start=0 and end=1. + hidden_state_shape_list = shape_list(hidden_state) + hidden_state_reshape_list = [ + hidden_state_shape_list[0] * hidden_state_shape_list[1] + ] + hidden_state_shape_list[2:] + + flattened_hidden_state = tf.reshape(hidden_state, shape=hidden_state_reshape_list) + batch_norm_hidden_state = self.batch_norm(flattened_hidden_state, training=training) + + # Reshape the output of batch norm to have the same shape as the original hidden state + hidden_state = tf.reshape(batch_norm_hidden_state, shape=shape_list(hidden_state)) + return hidden_state + + +class TFLevitSubsample(tf.keras.layers.Layer): + """ + Layer to subsample the activatioin maps. 
+ """ + + def __init__(self, stride, resolution, *args, **kwargs): + super().__init__(*args, **kwargs) + self.stride = stride + self.resolution = resolution + + def call(self, hidden_state: tf.Tensor, training: Optional[bool] = None): + batch_size = tf.shape(hidden_state)[0] + channels = tf.shape(hidden_state)[2] + + reshaped_hidden_state = tf.reshape( + hidden_state, shape=(batch_size, self.resolution, self.resolution, channels) + ) + strided_hidden_state = reshaped_hidden_state[:, :: self.stride, :: self.stride] + hidden_state = tf.reshape(strided_hidden_state, shape=(batch_size, -1, channels)) + + return hidden_state + + +class TFLevitAttention(tf.keras.layers.Layer): + def __init__(self, hidden_sizes, key_dim, num_attention_heads, attention_ratio, resolution, *args, **kwargs): + super().__init__(*args, **kwargs) + self.num_attention_heads = num_attention_heads + self.scale = key_dim**-0.5 + self.key_dim = key_dim + self.attention_ratio = attention_ratio + self.out_dim_keys_values = attention_ratio * key_dim * num_attention_heads + key_dim * num_attention_heads * 2 + self.out_dim_projection = attention_ratio * key_dim * num_attention_heads + + self.queries_keys_values = TFMLPLayerWithBN( + input_dim=hidden_sizes, output_dim=self.out_dim_keys_values, name="queries_keys_values" + ) + self.activation = hard_swish + self.projection = TFMLPLayerWithBN( + input_dim=self.out_dim_projection, output_dim=hidden_sizes, bn_weight_init=0, name="projection" + ) + + # Build tuples of points in the entire resolution range of the pixel values + points = list(itertools.product(range(resolution), range(resolution))) + self.len_points = len(points) + + # Initialize the attention offsets and indices + self.attention_offsets, self.indices = {}, [] + + # Iterate over the `points`` generator and calculate the offset between the initial + # point (0, 0) and the rest of the points [(0, 1), (0, 2)...] 
+        for p1 in points:  # iterate over every point taken as the query position
+            for p2 in points:  # iterate over every point taken as the key position
+                offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1]))
+                if offset not in self.attention_offsets:
+                    self.attention_offsets[offset] = len(self.attention_offsets)
+                self.indices.append(self.attention_offsets[offset])
+
+        # Store attention bias cache
+        self.attention_bias_cache = {}
+
+    def build(self, input_shape: tf.TensorShape):
+        self.attention_biases = self.add_weight(
+            shape=(self.num_attention_heads, len(self.attention_offsets)),
+            initializer="zeros",
+            trainable=True,
+            name="attention_biases",
+        )
+        self.attention_bias_idxs = tf.Variable(
+            initial_value=tf.reshape(self.indices, (self.len_points, self.len_points)),
+            trainable=False,  # this is a registered buffer and not a parameter
+            dtype=tf.int32,
+            name="attention_bias_idxs",
+        )
+        super().build(input_shape)
+
+    def get_attention_biases(self, device, training: Optional[bool] = None):
+        if training:
+            return tf.gather(self.attention_biases, self.attention_bias_idxs, axis=1)
+        else:
+            device_key = str(device)
+            if device_key not in self.attention_bias_cache:
+                self.attention_bias_cache[device_key] = tf.gather(
+                    self.attention_biases, self.attention_bias_idxs, axis=1
+                )
+            return self.attention_bias_cache[device_key]
+
+    def call(self, hidden_state: tf.Tensor, training: Optional[bool] = None):
+
+        # TODO: figure out the clearing cache mechanism
+        # if training and self.attention_bias_cache:
+        #     self.attention_bias_cache = {}  # clear ab cache
+
+        batch_size = tf.shape(hidden_state)[0]
+        seq_length = tf.shape(hidden_state)[1]
+        queries_keys_values = self.queries_keys_values(hidden_state)
+
+        # Reshape `queries_keys_values`.
+        reshaped_queries_keys_values = tf.reshape(
+            queries_keys_values, shape=(batch_size, seq_length, self.num_attention_heads, -1)
+        )
+        # Split the reshaped tensor into query, key, and value.
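+        # Along the last axis each attention head has `key_dim` query channels, `key_dim` key channels and
+        # `attention_ratio * key_dim` value channels, which determines the split sizes below.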
+ query, key, value = tf.split( + value=reshaped_queries_keys_values, + num_or_size_splits=[self.key_dim, self.key_dim, self.attention_ratio * self.key_dim], + axis=3, + ) + query = tf.transpose(query, perm=(0, 2, 1, 3)) + key = tf.transpose(key, perm=(0, 2, 1, 3)) + value = tf.transpose(value, perm=(0, 2, 1, 3)) + + attention = tf.matmul(query, key, transpose_b=True) * self.scale + self.get_attention_biases( + hidden_state.device, training=training + ) + attention = stable_softmax(attention, axis=-1) + hidden_state = tf.matmul(attention, value) + hidden_state = tf.transpose(hidden_state, perm=(0, 2, 1, 3)) + hidden_state = tf.reshape(hidden_state, shape=(batch_size, seq_length, self.out_dim_projection)) + hidden_state = self.projection(self.activation(hidden_state)) + return hidden_state + + +class TFLevitAttentionSubsample(tf.keras.layers.Layer): + def __init__( + self, + input_dim, + output_dim, + key_dim, + num_attention_heads, + attention_ratio, + stride, + resolution_in, + resolution_out, + *args, + **kwargs, + ): + super().__init__(*args, **kwargs) + self.num_attention_heads = num_attention_heads + self.scale = key_dim**-0.5 + self.key_dim = key_dim + self.attention_ratio = attention_ratio + self.out_dim_keys_values = attention_ratio * key_dim * num_attention_heads + key_dim * num_attention_heads + self.out_dim_projection = attention_ratio * key_dim * num_attention_heads + self.resolution_out = resolution_out + # resolution_in is the intial resolution, resoloution_out is final resolution after downsampling + self.keys_values = TFMLPLayerWithBN( + input_dim=input_dim, output_dim=self.out_dim_keys_values, name="keys_values" + ) + self.queries_subsample = TFLevitSubsample(stride=stride, resolution=resolution_in, name="queries_subsample") + self.queries = TFMLPLayerWithBN(input_dim=input_dim, output_dim=key_dim * num_attention_heads, name="queries") + self.activation = hard_swish + self.projection = TFMLPLayerWithBN(input_dim=self.out_dim_projection, output_dim=output_dim, name="projection") + + self.attention_bias_cache = {} + + points = list(itertools.product(range(resolution_in), range(resolution_in))) + points_ = list(itertools.product(range(resolution_out), range(resolution_out))) + self.len_points, self.len_points_ = len(points), len(points_) + attention_offsets, indices = {}, [] + for p1 in points_: + for p2 in points: + size = 1 + offset = (abs(p1[0] * stride - p2[0] + (size - 1) / 2), abs(p1[1] * stride - p2[1] + (size - 1) / 2)) + if offset not in attention_offsets: + attention_offsets[offset] = len(attention_offsets) + indices.append(attention_offsets[offset]) + + self.attention_offsets = attention_offsets + self.indices = indices + + def build(self, input_shape: tf.TensorShape): + self.attention_biases = self.add_weight( + shape=(self.num_attention_heads, len(self.attention_offsets)), + initializer="zeros", + trainable=True, + name="attention_biases", + ) + + self.attention_bias_idxs = tf.Variable( + initial_value=tf.reshape(self.indices, (self.len_points_, self.len_points)), + trainable=False, + dtype=tf.int32, + name="attention_bias_idxs", + ) + super().build(input_shape) + + def get_attention_biases(self, device, training: Optional[bool] = None): + if training: + return tf.gather(self.attention_biases, self.attention_bias_idxs, axis=1) + else: + device_key = str(device) + if device_key not in self.attention_bias_cache: + self.attention_bias_cache[device_key] = tf.gather( + self.attention_biases, self.attention_bias_idxs, axis=1 + ) + return 
self.attention_bias_cache[device_key]
+
+    def call(self, hidden_state: tf.Tensor, training: Optional[bool] = None):
+
+        # TODO: figure out the clearing cache mechanism
+        # if training and self.attention_bias_cache:
+        #     self.attention_bias_cache = {}  # clear ab cache
+
+        batch_size = tf.shape(hidden_state)[0]
+        seq_length = tf.shape(hidden_state)[1]
+
+        # Process the hidden states and reshape it
+        reshaped_hidden_state = tf.reshape(
+            self.keys_values(hidden_state), shape=(batch_size, seq_length, self.num_attention_heads, -1)
+        )
+        # Split the reshaped hidden state into key and value
+        key, value = tf.split(
+            reshaped_hidden_state,
+            num_or_size_splits=[self.key_dim, self.attention_ratio * self.key_dim],
+            axis=3,
+        )
+        key = tf.transpose(key, perm=(0, 2, 1, 3))
+        value = tf.transpose(value, perm=(0, 2, 1, 3))
+
+        query = self.queries(self.queries_subsample(hidden_state))
+        query = tf.reshape(query, shape=(batch_size, self.resolution_out**2, self.num_attention_heads, self.key_dim))
+        query = tf.transpose(query, perm=(0, 2, 1, 3))
+
+        attention = tf.matmul(query, key, transpose_b=True) * self.scale + self.get_attention_biases(
+            hidden_state.device, training=training
+        )
+        attention = stable_softmax(attention, axis=-1)
+        hidden_state = tf.matmul(attention, value)
+        hidden_state = tf.transpose(hidden_state, perm=(0, 2, 1, 3))
+        hidden_state = tf.reshape(hidden_state, (batch_size, -1, self.out_dim_projection))
+        hidden_state = self.projection(self.activation(hidden_state), training=training)
+        return hidden_state
+
+
+class TFLevitMLPLayer(tf.keras.layers.Layer):
+    """
+    MLP Layer with `2X` expansion in contrast to ViT with `4X`.
+    """
+
+    def __init__(self, input_dim, hidden_dim, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.linear_up = TFMLPLayerWithBN(input_dim=input_dim, output_dim=hidden_dim, name="linear_up")
+        self.activation = hard_swish
+        self.linear_down = TFMLPLayerWithBN(input_dim=hidden_dim, output_dim=input_dim, name="linear_down")
+
+    def call(self, hidden_state: tf.Tensor, training: Optional[bool] = None):
+        hidden_state = self.linear_up(hidden_state, training=training)
+        hidden_state = self.activation(hidden_state)
+        hidden_state = self.linear_down(hidden_state, training=training)
+        return hidden_state
+
+
+class TFLevitResidualLayer(tf.keras.layers.Layer):
+    """
+    Residual Block for TFLeViT
+    """
+
+    def __init__(self, module, drop_rate, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.module = module
+        self.drop_rate = drop_rate
+
+    def call(self, hidden_state: tf.Tensor, training: Optional[bool] = None):
+        if training and self.drop_rate > 0.0:
+            # Stochastic depth (drop path): sample a per-example keep mask from a uniform distribution,
+            # rescale by the keep probability and block its gradient, as in the PyTorch implementation.
+            rnd = tf.random.uniform(shape=(tf.shape(hidden_state)[0], 1, 1), minval=0, maxval=1)
+            rnd = tf.cast(tf.math.greater(rnd, self.drop_rate), dtype=hidden_state.dtype)
+            rnd = tf.math.divide(rnd, (1 - self.drop_rate))
+            # Detach the gradient from `rnd`.
+            rnd = tf.stop_gradient(rnd)
+            hidden_state = hidden_state + self.module(hidden_state) * rnd
+            return hidden_state
+        else:
+            hidden_state = hidden_state + self.module(hidden_state)
+            return hidden_state
+
+
+class TFLevitStage(tf.keras.layers.Layer):
+    """
+    LeViT Stage consisting of `TFLevitMLPLayer` and `TFLevitAttention` layers.
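+    Depending on `down_ops`, a stage may also contain a `TFLevitAttentionSubsample` block that reduces the spatial
+    resolution before the following stage.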
+ """ + + def __init__( + self, + config, + idx, + hidden_sizes, + key_dim, + depths, + num_attention_heads, + attention_ratio, + mlp_ratio, + down_ops, + resolution_in, + *args, + **kwargs, + ): + super().__init__(*args, **kwargs) + self.layers = [] + self.config = config + self.resolution_in = resolution_in + # `resolution_in` is the intial resolution, `resolution_out` is final resolution after downsampling + index = 0 + for _ in range(depths): + self.layers.append( + TFLevitResidualLayer( + module=TFLevitAttention( + hidden_sizes=hidden_sizes, + key_dim=key_dim, + num_attention_heads=num_attention_heads, + attention_ratio=attention_ratio, + resolution=resolution_in, + name="module", + ), + drop_rate=self.config.drop_path_rate, + name=f"layers.{index}", + ) + ) + index += 1 # Increment the index by 1 + if mlp_ratio > 0: + hidden_dim = hidden_sizes * mlp_ratio + self.layers.append( + TFLevitResidualLayer( + module=TFLevitMLPLayer( + input_dim=hidden_sizes, + hidden_dim=hidden_dim, + name="module", + ), + drop_rate=self.config.drop_path_rate, + name=f"layers.{index}", + ) + ) + index += 1 # Increment the index by 1 + + if down_ops[0] == "Subsample": + self.resolution_out = (self.resolution_in - 1) // down_ops[5] + 1 + self.layers.append( + TFLevitAttentionSubsample( + *self.config.hidden_sizes[idx : idx + 2], + key_dim=down_ops[1], + num_attention_heads=down_ops[2], + attention_ratio=down_ops[3], + stride=down_ops[5], + resolution_in=resolution_in, + resolution_out=self.resolution_out, + name=f"layers.{index}", + ) + ) + index += 1 # Increment the index by 1 + self.resolution_in = self.resolution_out + if down_ops[4] > 0: + hidden_dim = self.config.hidden_sizes[idx + 1] * down_ops[4] + self.layers.append( + TFLevitResidualLayer( + module=TFLevitMLPLayer( + input_dim=self.config.hidden_sizes[idx + 1], hidden_dim=hidden_dim, name="module" + ), + drop_rate=self.config.drop_path_rate, + name=f"layers.{index}", + ), + ) + index += 1 # Increment the index by 1 + + def get_resolution(self): + return self.resolution_in + + def call(self, hidden_state: tf.Tensor, training: Optional[bool] = None): + for layer in self.layers: + hidden_state = layer(hidden_state, training=training) + return hidden_state + + +class TFLevitEncoder(tf.keras.layers.Layer): + """ + LeViT Encoder consisting of multiple `TFLevitStage` stages. 
+ """ + + def __init__(self, config, *args, **kwargs): + super().__init__(*args, **kwargs) + self.config = config + resolution = self.config.image_size // self.config.patch_size + self.stages = [] + self.config.down_ops.append([""]) + + for stage_idx in range(len(config.depths)): + stage = TFLevitStage( + config=config, + idx=stage_idx, + hidden_sizes=config.hidden_sizes[stage_idx], + key_dim=config.key_dim[stage_idx], + depths=config.depths[stage_idx], + num_attention_heads=config.num_attention_heads[stage_idx], + attention_ratio=config.attention_ratio[stage_idx], + mlp_ratio=config.mlp_ratio[stage_idx], + down_ops=config.down_ops[stage_idx], + resolution_in=resolution, + name=f"stages.{stage_idx}", + ) + resolution = stage.get_resolution() + self.stages.append(stage) + + def call( + self, + hidden_state: tf.Tensor, + output_hidden_states: bool = False, + return_dict: bool = True, + training: Optional[bool] = None, + ): + all_hidden_states = () if output_hidden_states else None + + for stage in self.stages: + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_state,) + hidden_state = stage(hidden_state, training=training) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_state,) + if not return_dict: + return tuple(v for v in [hidden_state, all_hidden_states] if v is not None) + + return TFBaseModelOutputWithNoAttention(last_hidden_state=hidden_state, hidden_states=all_hidden_states) + + +class TFLevitClassificationLayer(tf.keras.layers.Layer): + """ + LeViT Classification Layer + """ + + def __init__(self, input_dim, output_dim, *args, **kwargs): + super().__init__(*args, **kwargs) + + # The epsilon and momentum used here are the defaults in torch batch norm layer. + self.batch_norm = tf.keras.layers.BatchNormalization(epsilon=1e-05, momentum=0.9, name="batch_norm") + self.linear = tf.keras.layers.Dense(units=output_dim, name="linear") + + def call(self, hidden_state: tf.Tensor, training: Optional[bool] = None): + hidden_state = self.batch_norm(hidden_state, training=training) + logits = self.linear(hidden_state, training=training) + return logits + + +class TFLevitPreTrainedModel(TFPreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = LevitConfig + base_model_prefix = "levit" + main_input_name = "pixel_values" + + @property + def dummy_inputs(self) -> Dict[str, tf.Tensor]: + """ + Dummy inputs to build the network. + + Returns: + `Dict[str, tf.Tensor]`: The dummy inputs. + """ + VISION_DUMMY_INPUTS = tf.random.uniform( + shape=(3, self.config.num_channels, self.config.image_size, self.config.image_size), dtype=tf.float32 + ) + return {"pixel_values": tf.constant(VISION_DUMMY_INPUTS)} + + @tf.function( + input_signature=[ + { + "pixel_values": tf.TensorSpec((None, None, None, None), tf.float32, name="pixel_values"), + } + ] + ) + def serving(self, inputs): + """ + Method used for serving the model. + + Args: + inputs (`Dict[str, tf.Tensor]`): + The input of the saved model as a dictionary of tensors. 
+ """ + output = self.call(inputs) + + return self.serving_output(output) + + +@keras_serializable +class TFLevitMainLayer(tf.keras.layers.Layer): + config_class = LevitConfig + + def __init__(self, config, *args, **kwargs): + super().__init__(*args, **kwargs) + self.config = config + self.patch_embeddings = TFLevitPatchEmbeddings(config=config, name="patch_embeddings") + self.encoder = TFLevitEncoder(config=config, name="encoder") + + @unpack_inputs + def call( + self, + pixel_values: tf.Tensor = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + training: Optional[bool] = None, + ) -> Union[Tuple, TFBaseModelOutputWithPoolingAndNoAttention]: + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + # Apply patch embeddings to the pixel values + embeddings = self.patch_embeddings(pixel_values, training=training) + + # Apply encoder to the encoded pixel values + encoder_outputs = self.encoder( + embeddings, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + + # Obtain the `last_hidden_state` + last_hidden_state = encoder_outputs[0] # encoder_outputs.last_hidden_state + + # global average pooling, (batch_size, seq_length, hidden_sizes) -> (batch_size, hidden_sizes) + pooled_output = tf.math.reduce_mean(last_hidden_state, axis=1) + + if not return_dict: + return (last_hidden_state, pooled_output) + encoder_outputs[1:] + + return TFBaseModelOutputWithPoolingAndNoAttention( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, # only if the `output_hidden_states` is set to True + ) + + +LEVIT_START_DOCSTRING = r""" + + This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a [tf.keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it + as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and + behavior. + + + + TensorFlow models and layers in `transformers` accept two formats as input: + + - having all inputs as keyword arguments (like PyTorch models), or + - having all inputs as a list, tuple or dict in the first positional argument. + + The reason the second format is supported is that Keras methods prefer this format when passing inputs to models + and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just + pass your inputs and labels in any format that `model.fit()` supports! 
If, however, you want to use the second + format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with + the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first + positional argument: + + - a single Tensor with `pixel_values` only and nothing else: `model(pixel_values)` + - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring: + `model([pixel_values, attention_mask])` or `model([pixel_values, attention_mask, token_type_ids])` + - a dictionary with one or several input Tensors associated to the input names given in the docstring: + `model({"pixel_values": pixel_values, "token_type_ids": token_type_ids})` + + Note that when creating models and layers with + [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry + about any of this, as you can just pass inputs like you would to any other Python function! + + + + Args: + config ([`LevitConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +LEVIT_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`tf.Tensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Pixel values can be obtained using [`AutoFeatureExtractor`]. See + [`AutoFeatureExtractor.__call__`] for details. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + training (`bool`, *optional*, defaults to `False``): + Whether or not to use the model in training mode (some modules like dropout modules have different + behaviors between training and evaluation). 
+""" + + +@add_start_docstrings( + "The bare Levit model outputting raw features without any specific head on top.", + LEVIT_START_DOCSTRING, +) +class TFLevitModel(TFLevitPreTrainedModel): + def __init__(self, config, *args, **kwargs): + super().__init__(config, *args, **kwargs) + self.levit = TFLevitMainLayer(config=config, name="levit") + + @unpack_inputs + @add_start_docstrings_to_model_forward(LEVIT_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + processor_class=_FEAT_EXTRACTOR_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TFBaseModelOutputWithPoolingAndNoAttention, + config_class=_CONFIG_FOR_DOC, + modality="vision", + expected_output=_EXPECTED_OUTPUT_SHAPE, + ) + def call( + self, + pixel_values: tf.Tensor = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + training: Optional[bool] = None, + ): + outputs = self.levit( + pixel_values=pixel_values, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + return outputs + + def serving_output(self, output: TFBaseModelOutputWithPoolingAndNoAttention) -> TFBaseModelOutputWithPooling: + hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None + attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None + + return TFBaseModelOutputWithPooling( + last_hidden_state=output.last_hidden_state, + pooler_output=output.pooler_output, + hidden_states=hs, + attentions=attns, + ) + + +@add_start_docstrings( + """ + Levit Model with an image classification head on top (a linear layer on top of the pooled features), e.g. for + ImageNet. + """, + LEVIT_START_DOCSTRING, +) +class TFLevitForImageClassification(TFLevitPreTrainedModel): + def __init__(self, config, *args, **kwargs): + super().__init__(config, *args, **kwargs) + self.config = config + self.num_labels = config.num_labels + self.levit = TFLevitMainLayer(config=config, name="levit") + + # Classifier head + self.classifier = ( + TFLevitClassificationLayer( + input_dim=config.hidden_sizes[-1], output_dim=config.num_labels, name="classifier" + ) + if config.num_labels > 0 + else tf.keras.layers.Activation("linear", name="classifier") + ) + + @unpack_inputs + @add_start_docstrings_to_model_forward(LEVIT_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + processor_class=_FEAT_EXTRACTOR_FOR_DOC, + checkpoint=_IMAGE_CLASS_CHECKPOINT, + output_type=TFImageClassifierOutputWithNoAttention, + config_class=_CONFIG_FOR_DOC, + expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT, + ) + def call( + self, + pixel_values: tf.Tensor = None, + labels: Optional[tf.Tensor] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + training: Optional[bool] = None, + ): + r""" + labels (`tf.Tensor` of shape `(batch_size,)`, *optional*): + Labels for computing the image classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
+ """ + # Get the outputs from the levit main layer + outputs = self.levit( + pixel_values=pixel_values, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + + # Get the `last_hidden_state` and average it along the number of sequences + sequence_output = outputs[0] # outputs.last_hidden_state + sequence_output = tf.math.reduce_mean(sequence_output, axis=1) + + # Apply the classifier head and obtain the logits + logits = self.classifier(sequence_output, training=training) + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == tf.int64 or labels.dtype == tf.int32): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MeanSquaredError() + if self.num_labels == 1: + loss = loss_fct(tf.squeeze(logits), tf.squeeze(labels)) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CategoricalCrossentropy() + loss = loss_fct(tf.reshape(logits, shape=(-1, self.num_labels)), tf.flatten(labels)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BinaryCrossentropy() + loss = loss_fct(logits, labels) + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TFImageClassifierOutputWithNoAttention( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, # only if `output_hidden_states` flag is set to True + ) + + +@add_start_docstrings( + """ + LeViT Model transformer with image classification heads on top (a linear layer on top of the final hidden state and + a linear layer on top of the final hidden state of the distillation token) e.g. for ImageNet. .. warning:: + This model supports inference-only. Fine-tuning with distillation (i.e. with a teacher) is not yet + supported. 
+ """, + LEVIT_START_DOCSTRING, +) +class TFLevitForImageClassificationWithTeacher(TFLevitPreTrainedModel): + def __init__(self, config, *args, **kwargs): + super().__init__(config, *args, **kwargs) + self.config = config + self.num_labels = config.num_labels + self.levit = TFLevitMainLayer(config, name="levit") + + # Classifier head + self.classifier = ( + TFLevitClassificationLayer( + input_dim=config.hidden_sizes[-1], + output_dim=config.num_labels, + name="classifier", + ) + if config.num_labels > 0 + else tf.keras.layers.Activation("linear", name="classifier") + ) + self.classifier_distill = ( + TFLevitClassificationLayer( + input_dim=config.hidden_sizes[-1], + output_dim=config.num_labels, + name="classifier_distill", + ) + if config.num_labels > 0 + else tf.keras.layers.Activation("linear", name="classifier_distill") + ) + + @unpack_inputs + @add_start_docstrings_to_model_forward(LEVIT_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + processor_class=_FEAT_EXTRACTOR_FOR_DOC, + checkpoint=_IMAGE_CLASS_CHECKPOINT, + output_type=TFLevitForImageClassificationWithTeacherOutput, + config_class=_CONFIG_FOR_DOC, + expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT, + ) + def call( + self, + pixel_values: tf.Tensor = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + training: Optional[bool] = None, + ): + # Get the output from the levit main layer + outputs = self.levit( + pixel_values=pixel_values, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + + # Get the `last_hidden_state` and average it along the number of sequences + sequence_output = outputs[0] # outputs.last_hidden_state + sequence_output = tf.math.reduce_mean(sequence_output, axis=1) + + # Apply the classifier heads and obtain the `cls_logits` and `distill_logits` + cls_logits = self.classifier(sequence_output, training=training) + distill_logits = self.classifier_distill(sequence_output, training=training) + + # According to the paper, the cls and distill logits are averaged + logits = (cls_logits + distill_logits) / 2 + + if not return_dict: + output = (logits, cls_logits, distill_logits) + outputs[2:] + return output + + return TFLevitForImageClassificationWithTeacherOutput( + logits=logits, + cls_logits=cls_logits, + distillation_logits=distill_logits, + hidden_states=outputs.hidden_states, # only if `output_hidden_states` flag is set to True + ) diff --git a/src/transformers/utils/dummy_tf_objects.py b/src/transformers/utils/dummy_tf_objects.py index a24e53c50886..aaeb6529e228 100644 --- a/src/transformers/utils/dummy_tf_objects.py +++ b/src/transformers/utils/dummy_tf_objects.py @@ -1506,6 +1506,37 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["tf"]) +TF_LEVIT_PRETRAINED_MODEL_ARCHIVE_LIST = None + + +class TFLevitForImageClassification(metaclass=DummyObject): + _backends = ["tf"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["tf"]) + + +class TFLevitForImageClassificationWithTeacher(metaclass=DummyObject): + _backends = ["tf"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["tf"]) + + +class TFLevitModel(metaclass=DummyObject): + _backends = ["tf"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["tf"]) + + +class TFLevitPreTrainedModel(metaclass=DummyObject): + _backends = ["tf"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["tf"]) + + TF_LONGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = None