From 94ca09e5ef60b2d7889686c20e6c71aac7d168ef Mon Sep 17 00:00:00 2001 From: Shubhamai Date: Wed, 1 Mar 2023 14:24:58 +0530 Subject: [PATCH 1/4] initial commit --- docs/source/de/index.mdx | 2 +- docs/source/en/index.mdx | 2 +- docs/source/en/model_doc/regnet.mdx | 14 +- docs/source/es/index.mdx | 2 +- docs/source/fr/index.mdx | 2 +- docs/source/it/index.mdx | 2 +- docs/source/ja/index.mdx | 2 +- docs/source/ko/index.mdx | 2 +- docs/source/pt/index.mdx | 2 +- docs/source/zh/index.mdx | 2 +- src/transformers/__init__.py | 4 + src/transformers/modeling_flax_outputs.py | 57 ++ .../models/auto/modeling_flax_auto.py | 2 + src/transformers/models/regnet/__init__.py | 32 +- .../models/regnet/modeling_flax_regnet.py | 818 ++++++++++++++++++ src/transformers/utils/dummy_flax_objects.py | 21 + .../regnet/test_modeling_flax_regnet.py | 240 +++++ 17 files changed, 1195 insertions(+), 11 deletions(-) create mode 100644 src/transformers/models/regnet/modeling_flax_regnet.py create mode 100644 tests/models/regnet/test_modeling_flax_regnet.py diff --git a/docs/source/de/index.mdx b/docs/source/de/index.mdx index f82aa44ea6bd..9b18216f6e77 100644 --- a/docs/source/de/index.mdx +++ b/docs/source/de/index.mdx @@ -281,7 +281,7 @@ Flax), PyTorch, und/oder TensorFlow haben. | RAG | ✅ | ❌ | ✅ | ✅ | ❌ | | REALM | ✅ | ✅ | ✅ | ❌ | ❌ | | Reformer | ✅ | ✅ | ✅ | ❌ | ❌ | -| RegNet | ❌ | ❌ | ✅ | ✅ | ❌ | +| RegNet | ❌ | ❌ | ✅ | ✅ | ✅ | | RemBERT | ✅ | ✅ | ✅ | ✅ | ❌ | | ResNet | ❌ | ❌ | ✅ | ✅ | ❌ | | RetriBERT | ✅ | ✅ | ✅ | ❌ | ❌ | diff --git a/docs/source/en/index.mdx b/docs/source/en/index.mdx index bec252c79e0f..af9ab31071e6 100644 --- a/docs/source/en/index.mdx +++ b/docs/source/en/index.mdx @@ -360,7 +360,7 @@ Flax), PyTorch, and/or TensorFlow. | RAG | ✅ | ❌ | ✅ | ✅ | ❌ | | REALM | ✅ | ✅ | ✅ | ❌ | ❌ | | Reformer | ✅ | ✅ | ✅ | ❌ | ❌ | -| RegNet | ❌ | ❌ | ✅ | ✅ | ❌ | +| RegNet | ❌ | ❌ | ✅ | ✅ | ✅ | | RemBERT | ✅ | ✅ | ✅ | ✅ | ❌ | | ResNet | ❌ | ❌ | ✅ | ✅ | ❌ | | RetriBERT | ✅ | ✅ | ✅ | ❌ | ❌ | diff --git a/docs/source/en/model_doc/regnet.mdx b/docs/source/en/model_doc/regnet.mdx index e93eec6216d0..a6f7989355e5 100644 --- a/docs/source/en/model_doc/regnet.mdx +++ b/docs/source/en/model_doc/regnet.mdx @@ -67,4 +67,16 @@ If you're interested in submitting a resource to be included here, please feel f ## TFRegNetForImageClassification [[autodoc]] TFRegNetForImageClassification - - call \ No newline at end of file + - call + + +## FlaxRegNetModel + +[[autodoc]] FlaxRegNetModel + - __call__ + + +## FlaxRegNetForImageClassification + +[[autodoc]] FlaxRegNetForImageClassification + - __call__ \ No newline at end of file diff --git a/docs/source/es/index.mdx b/docs/source/es/index.mdx index 997b4f97460d..d2057a99982a 100644 --- a/docs/source/es/index.mdx +++ b/docs/source/es/index.mdx @@ -233,7 +233,7 @@ Flax), PyTorch y/o TensorFlow. | RAG | ✅ | ❌ | ✅ | ✅ | ❌ | | Realm | ✅ | ✅ | ✅ | ❌ | ❌ | | Reformer | ✅ | ✅ | ✅ | ❌ | ❌ | -| RegNet | ❌ | ❌ | ✅ | ❌ | ❌ | +| RegNet | ❌ | ❌ | ✅ | ❌ | ✅ | | RemBERT | ✅ | ✅ | ✅ | ✅ | ❌ | | ResNet | ❌ | ❌ | ✅ | ❌ | ❌ | | RetriBERT | ✅ | ✅ | ✅ | ❌ | ❌ | diff --git a/docs/source/fr/index.mdx b/docs/source/fr/index.mdx index f6b86d1bd328..337c17973cbe 100644 --- a/docs/source/fr/index.mdx +++ b/docs/source/fr/index.mdx @@ -345,7 +345,7 @@ Le tableau ci-dessous représente la prise en charge actuelle dans la bibliothè | RAG | ✅ | ❌ | ✅ | ✅ | ❌ | | REALM | ✅ | ✅ | ✅ | ❌ | ❌ | | Reformer | ✅ | ✅ | ✅ | ❌ | ❌ | -| RegNet | ❌ | ❌ | ✅ | ✅ | ❌ | +| RegNet | ❌ | ❌ | ✅ | ✅ | ✅ | | RemBERT | ✅ | ✅ | ✅ | ✅ | ❌ | | ResNet | ❌ | ❌ | ✅ | ✅ | ❌ | | RetriBERT | ✅ | ✅ | ✅ | ❌ | ❌ | diff --git a/docs/source/it/index.mdx b/docs/source/it/index.mdx index 7be478bd7791..e343cf5e8761 100644 --- a/docs/source/it/index.mdx +++ b/docs/source/it/index.mdx @@ -250,7 +250,7 @@ tokenizer (chiamato "slow"). Un tokenizer "fast" supportato dalla libreria 🤗 | RAG | ✅ | ❌ | ✅ | ✅ | ❌ | | Realm | ✅ | ✅ | ✅ | ❌ | ❌ | | Reformer | ✅ | ✅ | ✅ | ❌ | ❌ | -| RegNet | ❌ | ❌ | ✅ | ❌ | ❌ | +| RegNet | ❌ | ❌ | ✅ | ❌ | ✅ | | RemBERT | ✅ | ✅ | ✅ | ✅ | ❌ | | ResNet | ❌ | ❌ | ✅ | ❌ | ❌ | | RetriBERT | ✅ | ✅ | ✅ | ❌ | ❌ | diff --git a/docs/source/ja/index.mdx b/docs/source/ja/index.mdx index d98db263461b..04530db5ff24 100644 --- a/docs/source/ja/index.mdx +++ b/docs/source/ja/index.mdx @@ -336,7 +336,7 @@ specific language governing permissions and limitations under the License. | RAG | ✅ | ❌ | ✅ | ✅ | ❌ | | REALM | ✅ | ✅ | ✅ | ❌ | ❌ | | Reformer | ✅ | ✅ | ✅ | ❌ | ❌ | -| RegNet | ❌ | ❌ | ✅ | ✅ | ❌ | +| RegNet | ❌ | ❌ | ✅ | ✅ | ✅ | | RemBERT | ✅ | ✅ | ✅ | ✅ | ❌ | | ResNet | ❌ | ❌ | ✅ | ✅ | ❌ | | RetriBERT | ✅ | ✅ | ✅ | ❌ | ❌ | diff --git a/docs/source/ko/index.mdx b/docs/source/ko/index.mdx index 8073c2daeb83..a09d6ce3af2d 100644 --- a/docs/source/ko/index.mdx +++ b/docs/source/ko/index.mdx @@ -305,7 +305,7 @@ specific language governing permissions and limitations under the License. | RAG | ✅ | ❌ | ✅ | ✅ | ❌ | | REALM | ✅ | ✅ | ✅ | ❌ | ❌ | | Reformer | ✅ | ✅ | ✅ | ❌ | ❌ | -| RegNet | ❌ | ❌ | ✅ | ✅ | ❌ | +| RegNet | ❌ | ❌ | ✅ | ✅ | ✅ | | RemBERT | ✅ | ✅ | ✅ | ✅ | ❌ | | ResNet | ❌ | ❌ | ✅ | ✅ | ❌ | | RetriBERT | ✅ | ✅ | ✅ | ❌ | ❌ | diff --git a/docs/source/pt/index.mdx b/docs/source/pt/index.mdx index bdbd7385fdcf..6f01fc5c758d 100644 --- a/docs/source/pt/index.mdx +++ b/docs/source/pt/index.mdx @@ -249,7 +249,7 @@ disso, são diferenciados pelo suporte em diferentes frameworks: JAX (por meio d | RAG | ✅ | ❌ | ✅ | ✅ | ❌ | | Realm | ✅ | ✅ | ✅ | ❌ | ❌ | | Reformer | ✅ | ✅ | ✅ | ❌ | ❌ | -| RegNet | ❌ | ❌ | ✅ | ❌ | ❌ | +| RegNet | ❌ | ❌ | ✅ | ❌ | ✅ | | RemBERT | ✅ | ✅ | ✅ | ✅ | ❌ | | ResNet | ❌ | ❌ | ✅ | ❌ | ❌ | | RetriBERT | ✅ | ✅ | ✅ | ❌ | ❌ | diff --git a/docs/source/zh/index.mdx b/docs/source/zh/index.mdx index 4d69d590c692..5eb4c8bf229c 100644 --- a/docs/source/zh/index.mdx +++ b/docs/source/zh/index.mdx @@ -335,7 +335,7 @@ Flax), PyTorch, 和/或者 TensorFlow. | RAG | ✅ | ❌ | ✅ | ✅ | ❌ | | REALM | ✅ | ✅ | ✅ | ❌ | ❌ | | Reformer | ✅ | ✅ | ✅ | ❌ | ❌ | -| RegNet | ❌ | ❌ | ✅ | ✅ | ❌ | +| RegNet | ❌ | ❌ | ✅ | ✅ | ✅ | | RemBERT | ✅ | ✅ | ✅ | ✅ | ❌ | | ResNet | ❌ | ❌ | ✅ | ✅ | ❌ | | RetriBERT | ✅ | ✅ | ✅ | ❌ | ❌ | diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 6514acd20389..2216ba1c5958 100644 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -3561,6 +3561,9 @@ "FlaxPegasusPreTrainedModel", ] ) + _import_structure["models.regnet"].extend( + ["FlaxRegNetForImageClassification", "FlaxRegNetModel", "FlaxRegNetPreTrainedModel"] + ) _import_structure["models.roberta"].extend( [ "FlaxRobertaForCausalLM", @@ -6548,6 +6551,7 @@ from .models.mt5 import FlaxMT5EncoderModel, FlaxMT5ForConditionalGeneration, FlaxMT5Model from .models.opt import FlaxOPTForCausalLM, FlaxOPTModel, FlaxOPTPreTrainedModel from .models.pegasus import FlaxPegasusForConditionalGeneration, FlaxPegasusModel, FlaxPegasusPreTrainedModel + from .models.regnet import FlaxRegNetForImageClassification, FlaxRegNetModel, FlaxRegNetPreTrainedModel from .models.roberta import ( FlaxRobertaForCausalLM, FlaxRobertaForMaskedLM, diff --git a/src/transformers/modeling_flax_outputs.py b/src/transformers/modeling_flax_outputs.py index 4f6cc5a901f8..620c29e76bbe 100644 --- a/src/transformers/modeling_flax_outputs.py +++ b/src/transformers/modeling_flax_outputs.py @@ -45,6 +45,63 @@ class FlaxBaseModelOutput(ModelOutput): attentions: Optional[Tuple[jnp.ndarray]] = None +@flax.struct.dataclass +class FlaxBaseModelOutputWithNoAttention(ModelOutput): + """ + Args: + Base class for model's outputs, with potential hidden states. + last_hidden_state (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)`): + Sequence of hidden-states at the output of the last layer of the model. + hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when + `config.output_hidden_states=True`): + Tuple of `jnp.ndarray` (one for the output of the embeddings, if the model has an embedding layer, + one + for the output of each layer) of shape `(batch_size, num_channels, height, width)`. Hidden-states of the + model at the output of each layer plus the optional initial embedding outputs. + """ + + last_hidden_state: jnp.ndarray = None + hidden_states: Optional[Tuple[jnp.ndarray]] = None + + +@flax.struct.dataclass +class FlaxBaseModelOutputWithPoolingAndNoAttention(ModelOutput): + """ + Args: + Base class for model's outputs that also contains a pooling of the last hidden states. + last_hidden_state (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)`): + Sequence of hidden-states at the output of the last layer of the model. + pooler_output (`jnp.ndarray` of shape `(batch_size, hidden_size)`): + Last layer hidden-state after a pooling operation on the spatial dimensions. + hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when + `config.output_hidden_states=True`): + Tuple of `jnp.ndarray` (one for the output of the embeddings, if the model has an embedding layer, + one + for the output of each layer) of shape `(batch_size, num_channels, height, width)`. Hidden-states of the + model at the output of each layer plus the optional initial embedding outputs. + """ + + last_hidden_state: jnp.ndarray = None + pooler_output: jnp.ndarray = None + hidden_states: Optional[Tuple[jnp.ndarray]] = None + + +@flax.struct.dataclass +class FlaxImageClassifierOutputWithNoAttention(ModelOutput): + """ + Args: + Base class for outputs of image classification models. + logits (`jnp.ndarray` of shape `(batch_size, config.num_labels)`): + Classification (or regression if config.num_labels==1) scores (before SoftMax). + hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when + `config.output_hidden_states=True`): + Tuple of `jnp.ndarray` (one for the output of the embeddings, if the model has an embedding layer, + one + for the output of each stage) of shape `(batch_size, num_channels, height, width)`. Hidden-states (also + called feature maps) of the model at the output of each stage. + """ + + logits: jnp.ndarray = None + hidden_states: Optional[Tuple[jnp.ndarray]] = None + + @flax.struct.dataclass class FlaxBaseModelOutputWithPast(ModelOutput): """ diff --git a/src/transformers/models/auto/modeling_flax_auto.py b/src/transformers/models/auto/modeling_flax_auto.py index 77be9b33f0a7..690835863454 100644 --- a/src/transformers/models/auto/modeling_flax_auto.py +++ b/src/transformers/models/auto/modeling_flax_auto.py @@ -48,6 +48,7 @@ ("mt5", "FlaxMT5Model"), ("opt", "FlaxOPTModel"), ("pegasus", "FlaxPegasusModel"), + ("regnet", "FlaxRegNetModel"), ("roberta", "FlaxRobertaModel"), ("roberta-prelayernorm", "FlaxRobertaPreLayerNormModel"), ("roformer", "FlaxRoFormerModel"), @@ -119,6 +120,7 @@ [ # Model for Image-classsification ("beit", "FlaxBeitForImageClassification"), + ("regnet", "FlaxRegNetForImageClassification"), ("vit", "FlaxViTForImageClassification"), ] ) diff --git a/src/transformers/models/regnet/__init__.py b/src/transformers/models/regnet/__init__.py index 91221e9012cb..5084c4486008 100644 --- a/src/transformers/models/regnet/__init__.py +++ b/src/transformers/models/regnet/__init__.py @@ -13,7 +13,13 @@ # limitations under the License. from typing import TYPE_CHECKING -from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_tf_available, is_torch_available +from ...utils import ( + OptionalDependencyNotAvailable, + _LazyModule, + is_flax_available, + is_tf_available, + is_torch_available, +) _import_structure = {"configuration_regnet": ["REGNET_PRETRAINED_CONFIG_ARCHIVE_MAP", "RegNetConfig"]} @@ -44,6 +50,18 @@ "TFRegNetPreTrainedModel", ] +try: + if not is_flax_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["modeling_flax_regnet"] = [ + "FlaxRegNetForImageClassification", + "FlaxRegNetModel", + "FlaxRegNetPreTrainedModel", + ] + if TYPE_CHECKING: from .configuration_regnet import REGNET_PRETRAINED_CONFIG_ARCHIVE_MAP, RegNetConfig @@ -74,6 +92,18 @@ TFRegNetPreTrainedModel, ) + try: + if not is_flax_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .modeling_flax_regnet import ( + FlaxRegNetForImageClassification, + FlaxRegNetModel, + FlaxRegNetPreTrainedModel, + ) + else: import sys diff --git a/src/transformers/models/regnet/modeling_flax_regnet.py b/src/transformers/models/regnet/modeling_flax_regnet.py new file mode 100644 index 000000000000..dffdc410fbcc --- /dev/null +++ b/src/transformers/models/regnet/modeling_flax_regnet.py @@ -0,0 +1,818 @@ +# coding=utf-8 +# Copyright 2023 The Google Flax Team Authors and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from functools import partial +from typing import Optional, Tuple + +import flax.linen as nn +import jax +import jax.numpy as jnp +from flax.core.frozen_dict import FrozenDict, freeze, unfreeze +from flax.traverse_util import flatten_dict, unflatten_dict + +from transformers import RegNetConfig +from transformers.modeling_flax_outputs import ( + FlaxBaseModelOutputWithNoAttention, + FlaxBaseModelOutputWithPooling, + FlaxBaseModelOutputWithPoolingAndNoAttention, + FlaxImageClassifierOutputWithNoAttention, +) +from transformers.modeling_flax_utils import ( + ACT2FN, + FlaxPreTrainedModel, + append_replace_return_docstrings, + overwrite_call_docstring, +) +from transformers.utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, +) + + +REGNET_START_DOCSTRING = r""" + + This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading, saving and converting weights from PyTorch models) + + This model is also a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module) + subclass. Use it as a regular Flax linen Module and refer to the Flax documentation for all matter related to + general usage and behavior. + + Finally, this model supports inherent JAX features such as: + + - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit) + - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) + - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) + - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap) + + Parameters: + config ([`RegNetConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights. + dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`): + The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and + `jax.numpy.bfloat16` (on TPUs). + + This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If + specified all the computation will be performed with the given `dtype`. + + **Note that this only specifies the dtype of the computation and does not influence the dtype of model + parameters.** + + If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and + [`~FlaxPreTrainedModel.to_bf16`]. +""" + +REGNET_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`numpy.ndarray` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See + [`RegNetImageProcessor.__call__`] for details. + + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +class Identity(nn.Module): + """Identity function.""" + + @nn.compact + def __call__(self, x, deterministic=None): + return x + + +class FlaxRegNetConvLayer(nn.Module): + out_channels: int + kernel_size: int = 3 + stride: int = 1 + groups: int = 1 + activation: Optional[str] = "relu" + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.convolution = nn.Conv( + self.out_channels, + kernel_size=(self.kernel_size, self.kernel_size), + strides=self.stride, + padding=self.kernel_size // 2, + feature_group_count=self.groups, + use_bias=False, + kernel_init=nn.initializers.variance_scaling(2.0, mode="fan_out", distribution="truncated_normal"), + dtype=self.dtype, + ) + self.normalization = nn.BatchNorm(momentum=0.9, epsilon=1e-05, dtype=self.dtype) + self.activation_func = ACT2FN[self.activation] if self.activation is not None else Identity() + + def __call__(self, hidden_state: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray: + hidden_state = self.convolution(hidden_state) + hidden_state = self.normalization(hidden_state, use_running_average=deterministic) + hidden_state = self.activation_func(hidden_state) + return hidden_state + + +class FlaxRegNetEmbeddings(nn.Module): + config: RegNetConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.embedder = FlaxRegNetConvLayer( + self.config.embedding_size, + kernel_size=3, + stride=2, + activation=self.config.hidden_act, + dtype=self.dtype, + ) + + def __call__(self, pixel_values: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray: + num_channels = pixel_values.shape[-1] + if num_channels != self.config.num_channels: + raise ValueError( + "Make sure that the channel dimension of the pixel values match with the one set in the configuration." + ) + hidden_state = self.embedder(pixel_values, deterministic=deterministic) + return hidden_state + + +class FlaxRegNetShortCut(nn.Module): + """ + RegNet shortcut, used to project the residual features to the correct size. If needed, it is also used to + downsample the input using `stride=2`. + """ + + out_channels: int + stride: int = 2 + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.convolution = nn.Conv( + self.out_channels, + kernel_size=(1, 1), + strides=self.stride, + use_bias=False, + kernel_init=nn.initializers.variance_scaling(2.0, mode="fan_out", distribution="truncated_normal"), + dtype=self.dtype, + ) + self.normalization = nn.BatchNorm(momentum=0.9, epsilon=1e-05, dtype=self.dtype) + + def __call__(self, input: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray: + hidden_state = self.convolution(input) + hidden_state = self.normalization(hidden_state, use_running_average=deterministic) + return hidden_state + + +class FlaxRegNetSELayerCollection(nn.Module): + in_channels: int + reduced_channels: int + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.conv_1 = nn.Conv( + self.reduced_channels, + kernel_size=(1, 1), + kernel_init=nn.initializers.variance_scaling(2.0, mode="fan_out", distribution="truncated_normal"), + dtype=self.dtype, + name="0", + ) # 0 is the name used in corresponding pytorch implementation + self.conv_2 = nn.Conv( + self.in_channels, + kernel_size=(1, 1), + kernel_init=nn.initializers.variance_scaling(2.0, mode="fan_out", distribution="truncated_normal"), + dtype=self.dtype, + name="2", + ) # 2 is the name used in corresponding pytorch implementation + + def __call__(self, hidden_state: jnp.ndarray) -> jnp.ndarray: + # b h w c -> b 1 1 c + + hidden_state = self.conv_1(hidden_state) + hidden_state = nn.relu(hidden_state) + hidden_state = self.conv_2(hidden_state) + attention = nn.sigmoid(hidden_state) + + return attention + + +class FlaxRegNetSELayer(nn.Module): + """ + Squeeze and Excitation layer (SE) proposed in [Squeeze-and-Excitation Networks](https://arxiv.org/abs/1709.01507). + """ + + in_channels: int + reduced_channels: int + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.pooler = partial(nn.avg_pool, padding=((0, 0), (0, 0))) + self.attention = FlaxRegNetSELayerCollection(self.in_channels, self.reduced_channels, dtype=self.dtype) + + def __call__(self, hidden_state: jnp.ndarray) -> jnp.ndarray: + # b h w c -> b 1 1 c + pooled = self.pooler( + hidden_state, + window_shape=(hidden_state.shape[1], hidden_state.shape[2]), + strides=(hidden_state.shape[1], hidden_state.shape[2]), + ) + attention = self.attention(pooled) + hidden_state = hidden_state * attention + return hidden_state + + +class FlaxRegNetXLayerCollection(nn.Module): + config: RegNetConfig + out_channels: int + stride: int = 1 + dtype: jnp.dtype = jnp.float32 + + def setup(self): + groups = max(1, self.out_channels // self.config.groups_width) + + self.layer = [ + FlaxRegNetConvLayer( + self.out_channels, + kernel_size=1, + activation=self.config.hidden_act, + dtype=self.dtype, + name="0", + ), + FlaxRegNetConvLayer( + self.out_channels, + stride=self.stride, + groups=groups, + activation=self.config.hidden_act, + dtype=self.dtype, + name="1", + ), + FlaxRegNetConvLayer( + self.out_channels, + kernel_size=1, + activation=None, + dtype=self.dtype, + name="2", + ), + ] + + def __call__(self, hidden_state: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray: + for layer in self.layer: + hidden_state = layer(hidden_state, deterministic=deterministic) + return hidden_state + + +class FlaxRegNetXLayer(nn.Module): + """ + RegNet's layer composed by three `3x3` convolutions, same as a ResNet bottleneck layer with reduction = 1. + """ + + config: RegNetConfig + in_channels: int + out_channels: int + stride: int = 1 + dtype: jnp.dtype = jnp.float32 + + def setup(self): + should_apply_shortcut = self.in_channels != self.out_channels or self.stride != 1 + self.shortcut = ( + FlaxRegNetShortCut( + self.out_channels, + stride=self.stride, + dtype=self.dtype, + ) + if should_apply_shortcut + else Identity() + ) + self.layer = FlaxRegNetXLayerCollection( + self.config, + in_channels=self.in_channels, + out_channels=self.out_channels, + stride=self.stride, + dtype=self.dtype, + ) + self.activation_func = ACT2FN[self.config.hidden_act] + + def __call__(self, hidden_state: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray: + residual = hidden_state + hidden_state = self.layer(hidden_state) + residual = self.shortcut(residual, deterministic=deterministic) + hidden_state += residual + hidden_state = self.activation_func(hidden_state) + return hidden_state + + +class FlaxRegNetYLayerCollection(nn.Module): + config: RegNetConfig + in_channels: int + out_channels: int + stride: int = 1 + dtype: jnp.dtype = jnp.float32 + + def setup(self): + groups = max(1, self.out_channels // self.config.groups_width) + + self.layer = [ + FlaxRegNetConvLayer( + self.out_channels, + kernel_size=1, + activation=self.config.hidden_act, + dtype=self.dtype, + name="0", + ), + FlaxRegNetConvLayer( + self.out_channels, + stride=self.stride, + groups=groups, + activation=self.config.hidden_act, + dtype=self.dtype, + name="1", + ), + FlaxRegNetSELayer( + self.out_channels, + reduced_channels=int(round(self.in_channels / 4)), + dtype=self.dtype, + name="2", + ), + FlaxRegNetConvLayer( + self.out_channels, + kernel_size=1, + activation=None, + dtype=self.dtype, + name="3", + ), + ] + + def __call__(self, hidden_state: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray: + for layer in self.layer: + hidden_state = layer(hidden_state) + return hidden_state + + +class FlaxRegNetYLayer(nn.Module): + """ + RegNet's Y layer: an X layer with Squeeze and Excitation. + """ + + config: RegNetConfig + in_channels: int + out_channels: int + stride: int = 1 + dtype: jnp.dtype = jnp.float32 + + def setup(self): + should_apply_shortcut = self.in_channels != self.out_channels or self.stride != 1 + + self.shortcut = ( + FlaxRegNetShortCut( + self.out_channels, + stride=self.stride, + dtype=self.dtype, + ) + if should_apply_shortcut + else Identity() + ) + self.layer = FlaxRegNetYLayerCollection( + self.config, + in_channels=self.in_channels, + out_channels=self.out_channels, + stride=self.stride, + dtype=self.dtype, + ) + self.activation_func = ACT2FN[self.config.hidden_act] + + def __call__(self, hidden_state: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray: + residual = hidden_state + hidden_state = self.layer(hidden_state) + residual = self.shortcut(residual, deterministic=deterministic) + hidden_state += residual + hidden_state = self.activation_func(hidden_state) + return hidden_state + + +class FlaxRegNetStageLayersCollection(nn.Module): + """ + A RegNet stage composed by stacked layers. + """ + + config: RegNetConfig + in_channels: int + out_channels: int + stride: int = 2 + depth: int = 2 + dtype: jnp.dtype = jnp.float32 + + def setup(self): + layer = FlaxRegNetXLayer if self.config.layer_type == "x" else FlaxRegNetYLayer + + self.layers = [ + # downsampling is done in the first layer with stride of 2 + layer( + self.config, + self.in_channels, + self.out_channels, + stride=self.stride, + dtype=self.dtype, + name="0", + ), + *[ + layer( + self.config, + self.out_channels, + self.out_channels, + dtype=self.dtype, + name=str(i + 1), + ) + for i in range(self.depth - 1) + ], + ] + + def __call__(self, x: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray: + hidden_state = x + for layer in self.layers: + hidden_state = layer(hidden_state, deterministic=deterministic) + return hidden_state + + +class FlaxRegNetStage(nn.Module): + """ + A RegNet stage composed by stacked layers. + """ + + config: RegNetConfig + in_channels: int + out_channels: int + stride: int = 2 + depth: int = 2 + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.layers = FlaxRegNetStageLayersCollection( + self.config, + in_channels=self.in_channels, + out_channels=self.out_channels, + stride=self.stride, + depth=self.depth, + dtype=self.dtype, + ) + + def __call__(self, hidden_state: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray: + hidden_state = self.layers(hidden_state, deterministic=deterministic) + return hidden_state + + +class FlaxRegNetStageCollection(nn.Module): + config: RegNetConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + in_out_channels = zip(self.config.hidden_sizes, self.config.hidden_sizes[1:]) + self.stages = [ + FlaxRegNetStage( + self.config, + self.config.embedding_size, + self.config.hidden_sizes[0], + stride=2 if self.config.downsample_in_first_stage else 1, + depth=self.config.depths[0], + dtype=self.dtype, + name="0", + ), + *[ + FlaxRegNetStage( + self.config, + in_channels, + out_channels, + depth=depth, + dtype=self.dtype, + name=str(i + 1), + ) + for i, ((in_channels, out_channels), depth) in enumerate(zip(in_out_channels, self.config.depths[1:])) + ], + ] + + def __call__( + self, + hidden_state: jnp.ndarray, + output_hidden_states: bool = False, + deterministic: bool = True, + ) -> FlaxBaseModelOutputWithNoAttention: + hidden_states = () if output_hidden_states else None + + for stage_module in self.stages: + if output_hidden_states: + hidden_states = hidden_states + (hidden_state.transpose(0, 3, 1, 2),) + + hidden_state = stage_module(hidden_state, deterministic=deterministic) + + return hidden_state, hidden_states + + +class FlaxRegNetEncoder(nn.Module): + config: RegNetConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.stages = FlaxRegNetStageCollection(self.config, dtype=self.dtype) + + def __call__( + self, + hidden_state: jnp.ndarray, + output_hidden_states: bool = False, + return_dict: bool = True, + deterministic: bool = True, + ) -> FlaxBaseModelOutputWithPooling: + hidden_states = () if output_hidden_states else None + + hidden_state, hidden_states = self.stages( + hidden_state, output_hidden_states=output_hidden_states, deterministic=deterministic + ) + + if output_hidden_states: + hidden_states = hidden_states + (hidden_state.transpose(0, 3, 1, 2),) + + if not return_dict: + return tuple(v for v in [hidden_state, hidden_states] if v is not None) + + return FlaxBaseModelOutputWithNoAttention(last_hidden_state=hidden_state, hidden_states=hidden_states) + + +class FlaxRegNetPreTrainedModel(FlaxPreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = RegNetConfig + base_model_prefix = "regnet" + main_input_name = "pixel_values" + module_class: nn.Module = None + + def __init__( + self, + config: RegNetConfig, + input_shape=(1, 224, 224, 3), + seed: int = 0, + dtype: jnp.dtype = jnp.float32, + _do_init: bool = True, + **kwargs, + ): + module = self.module_class(config=config, dtype=dtype, **kwargs) + super().__init__( + config, + module, + input_shape=input_shape, + seed=seed, + dtype=dtype, + _do_init=_do_init, + ) + + def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: + # init input tensors + pixel_values = jnp.zeros(input_shape, dtype=self.dtype) + + rngs = {"params": rng} + + random_params = self.module.init(rngs, pixel_values, return_dict=False) + + if params is not None: + random_params = flatten_dict(unfreeze(random_params)) + params = flatten_dict(unfreeze(params)) + for missing_key in self._missing_keys: + params[missing_key] = random_params[missing_key] + self._missing_keys = set() + return freeze(unflatten_dict(params)) + else: + return random_params + + @add_start_docstrings_to_model_forward(REGNET_INPUTS_DOCSTRING) + def __call__( + self, + pixel_values, + params: dict = None, + train: bool = False, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ): + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + pixel_values = jnp.transpose(pixel_values, (0, 2, 3, 1)) + + # Handle any PRNG if needed + rngs = {} + + return self.module.apply( + { + "params": params["params"] if params is not None else self.params["params"], + "batch_stats": params["batch_stats"] if params is not None else self.params["batch_stats"], + }, + jnp.array(pixel_values, dtype=jnp.float32), + not train, + output_hidden_states, + return_dict, + rngs=rngs, + mutable=["batch_stats"] if train else False, + ) + + +class FlaxRegNetModule(nn.Module): + config: RegNetConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.embedder = FlaxRegNetEmbeddings(self.config, dtype=self.dtype) + self.encoder = FlaxRegNetEncoder(self.config, dtype=self.dtype) + self.pooler = partial(nn.avg_pool, padding=((0, 0), (0, 0))) + + @add_start_docstrings_to_model_forward(REGNET_INPUTS_DOCSTRING) + def __call__( + self, + pixel_values: jnp.ndarray, + deterministic: bool = True, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> FlaxBaseModelOutputWithPoolingAndNoAttention: + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + embedding_output = self.embedder(pixel_values, deterministic=deterministic) + + encoder_outputs = self.encoder( + embedding_output, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=deterministic, + ) + + last_hidden_state = encoder_outputs[0] + + pooled_output = self.pooler( + last_hidden_state, + window_shape=(last_hidden_state.shape[1], last_hidden_state.shape[2]), + strides=(last_hidden_state.shape[1], last_hidden_state.shape[2]), + ).transpose(0, 3, 1, 2) + + last_hidden_state = last_hidden_state.transpose(0, 3, 1, 2) + + if not return_dict: + return (last_hidden_state, pooled_output) + encoder_outputs[1:] + + return FlaxBaseModelOutputWithPoolingAndNoAttention( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + ) + + +@add_start_docstrings( + "The bare RegNet model outputting raw features without any specific head on top.", + REGNET_START_DOCSTRING, +) +class FlaxRegNetModel(FlaxRegNetPreTrainedModel): + module_class = FlaxRegNetModule + + +FLAX_VISION_MODEL_DOCSTRING = """ + Returns: + + Examples: + + ```python + >>> from transformers import AutoImageProcessor, FlaxRegNetModel + >>> from PIL import Image + >>> import requests + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> image_processor = AutoImageProcessor.from_pretrained("facebook/regnet-y-040") + >>> model = FlaxRegNetModel.from_pretrained("facebook/regnet-y-040") + + >>> inputs = image_processor(images=image, return_tensors="np") + >>> outputs = model(**inputs) + >>> last_hidden_states = outputs.last_hidden_state + ``` +""" + +overwrite_call_docstring(FlaxRegNetModel, FLAX_VISION_MODEL_DOCSTRING) +append_replace_return_docstrings( + FlaxRegNetModel, + output_type=FlaxBaseModelOutputWithPooling, + config_class=RegNetConfig, +) + + +class FlaxRegNetClassifierCollection(nn.Module): + config: RegNetConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.classifier = nn.Dense(self.config.num_labels, dtype=self.dtype, name="1") + + def __call__(self, x: jnp.ndarray) -> jnp.ndarray: + return self.classifier(x) + + +class FlaxRegNetForImageClassificationModule(nn.Module): + config: RegNetConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.regnet = FlaxRegNetModule(config=self.config, dtype=self.dtype) + + if self.config.num_labels > 0: + self.classifier = FlaxRegNetClassifierCollection( + self.config, + dtype=self.dtype, + ) + else: + self.classifier = Identity() + + @add_start_docstrings_to_model_forward(REGNET_INPUTS_DOCSTRING) + def __call__( + self, + pixel_values=None, + deterministic: bool = True, + output_hidden_states=None, + return_dict=None, + ): + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.regnet( + pixel_values, + deterministic=deterministic, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = outputs.pooler_output if return_dict else outputs[1] + + logits = self.classifier(pooled_output[:, :, 0, 0]) + + if not return_dict: + output = (logits,) + outputs[2:] + return output + + return FlaxImageClassifierOutputWithNoAttention(logits=logits, hidden_states=outputs.hidden_states) + + +@add_start_docstrings( + """ + RegNet Model with an image classification head on top (a linear layer on top of the pooled features), e.g. for + ImageNet. + """, + REGNET_START_DOCSTRING, +) +class FlaxRegNetForImageClassification(FlaxRegNetPreTrainedModel): + module_class = FlaxRegNetForImageClassificationModule + + +FLAX_VISION_CLASSIF_DOCSTRING = """ + Returns: + + Example: + + ```python + >>> from transformers import AutoImageProcessor, FlaxRegNetForImageClassification + >>> from PIL import Image + >>> import jax + >>> import requests + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> image_processor = AutoImageProcessor.from_pretrained("facebook/regnet-y-040") + >>> model = FlaxRegNetForImageClassification.from_pretrained("facebook/regnet-y-040") + + >>> inputs = image_processor(images=image, return_tensors="np") + >>> outputs = model(**inputs) + >>> logits = outputs.logits + + >>> # model predicts one of the 1000 ImageNet classes + >>> predicted_class_idx = jax.numpy.argmax(logits, axis=-1) + >>> print("Predicted class:", model.config.id2label[predicted_class_idx.item()]) + ``` +""" + +overwrite_call_docstring(FlaxRegNetForImageClassification, FLAX_VISION_CLASSIF_DOCSTRING) +append_replace_return_docstrings( + FlaxRegNetForImageClassification, + output_type=FlaxImageClassifierOutputWithNoAttention, + config_class=RegNetConfig, +) diff --git a/src/transformers/utils/dummy_flax_objects.py b/src/transformers/utils/dummy_flax_objects.py index 60004790ec35..6c0dc8aa6a64 100644 --- a/src/transformers/utils/dummy_flax_objects.py +++ b/src/transformers/utils/dummy_flax_objects.py @@ -881,6 +881,27 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["flax"]) +class FlaxRegNetForImageClassification(metaclass=DummyObject): + _backends = ["flax"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + +class FlaxRegNetModel(metaclass=DummyObject): + _backends = ["flax"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + +class FlaxRegNetPreTrainedModel(metaclass=DummyObject): + _backends = ["flax"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + class FlaxRobertaForCausalLM(metaclass=DummyObject): _backends = ["flax"] diff --git a/tests/models/regnet/test_modeling_flax_regnet.py b/tests/models/regnet/test_modeling_flax_regnet.py new file mode 100644 index 000000000000..4f3ffae84d08 --- /dev/null +++ b/tests/models/regnet/test_modeling_flax_regnet.py @@ -0,0 +1,240 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import unittest + +from transformers import RegNetConfig, is_flax_available +from transformers.testing_utils import require_flax, slow +from transformers.utils import cached_property, is_vision_available + +from ...test_configuration_common import ConfigTester +from ...test_modeling_flax_common import FlaxModelTesterMixin, floats_tensor + + +if is_flax_available(): + import jax + import jax.numpy as jnp + + from transformers.models.regnet.modeling_flax_regnet import FlaxRegNetForImageClassification, FlaxRegNetModel + +if is_vision_available(): + from PIL import Image + + from transformers import AutoFeatureExtractor + + +class FlaxRegNetModelTester(unittest.TestCase): + def __init__( + self, + parent, + batch_size=1, + image_size=224, + num_channels=3, + embeddings_size=10, + hidden_sizes=[10, 20, 30, 40], + depths=[1, 1, 2, 1], + is_training=True, + use_labels=True, + hidden_act="relu", + num_labels=3, + scope=None, + ): + self.parent = parent + self.batch_size = batch_size + self.image_size = image_size + self.num_channels = num_channels + self.embeddings_size = embeddings_size + self.hidden_sizes = hidden_sizes + self.depths = depths + self.is_training = is_training + self.use_labels = use_labels + self.hidden_act = hidden_act + self.num_labels = num_labels + self.scope = scope + self.num_stages = len(hidden_sizes) + + def prepare_config_and_inputs(self): + pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size]) + + config = self.get_config() + + return config, pixel_values + + def get_config(self): + return RegNetConfig( + num_channels=self.num_channels, + embeddings_size=self.embeddings_size, + hidden_sizes=self.hidden_sizes, + depths=self.depths, + hidden_act=self.hidden_act, + num_labels=self.num_labels, + image_size=self.image_size, + ) + + def create_and_check_model(self, config, pixel_values): + model = FlaxRegNetModel(config=config) + result = model(pixel_values) + + # Output shape (b, c, h, w) + self.parent.assertEqual( + result.last_hidden_state.shape, + (self.batch_size, self.hidden_sizes[-1], self.image_size // 32, self.image_size // 32), + ) + + def create_and_check_for_image_classification(self, config, pixel_values): + config.num_labels = self.num_labels + model = FlaxRegNetForImageClassification(config=config) + result = model(pixel_values) + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels)) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + ( + config, + pixel_values, + ) = config_and_inputs + inputs_dict = {"pixel_values": pixel_values} + return config, inputs_dict + + +@require_flax +class FlaxResNetModelTest(FlaxModelTesterMixin, unittest.TestCase): + all_model_classes = (FlaxRegNetModel, FlaxRegNetForImageClassification) if is_flax_available() else () + + is_encoder_decoder = False + test_head_masking = False + has_attentions = False + + def setUp(self) -> None: + self.model_tester = FlaxRegNetModelTester(self) + self.config_tester = ConfigTester(self, config_class=RegNetConfig, has_text_modality=False) + + def test_config(self): + self.create_and_test_config_common_properties() + self.config_tester.create_and_test_config_to_json_string() + self.config_tester.create_and_test_config_to_json_file() + self.config_tester.create_and_test_config_from_and_save_pretrained() + self.config_tester.create_and_test_config_with_num_labels() + self.config_tester.check_config_can_be_init_without_params() + self.config_tester.check_config_arguments_init() + + def create_and_test_config_common_properties(self): + return + + def test_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(*config_and_inputs) + + def test_for_image_classification(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_for_image_classification(*config_and_inputs) + + @unittest.skip(reason="RegNet does not use inputs_embeds") + def test_inputs_embeds(self): + pass + + @unittest.skip(reason="RegNet does not support input and output embeddings") + def test_model_common_attributes(self): + pass + + def test_forward_signature(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + signature = inspect.signature(model.__call__) + # signature.parameters is an OrderedDict => so arg_names order is deterministic + arg_names = [*signature.parameters.keys()] + + expected_arg_names = ["pixel_values"] + self.assertListEqual(arg_names[:1], expected_arg_names) + + def test_hidden_states_output(self): + def check_hidden_states_output(inputs_dict, config, model_class): + model = model_class(config) + + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + + hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states + + expected_num_stages = self.model_tester.num_stages + self.assertEqual(len(hidden_states), expected_num_stages + 1) + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + inputs_dict["output_hidden_states"] = True + check_hidden_states_output(inputs_dict, config, model_class) + + # check that output_hidden_states also work using config + del inputs_dict["output_hidden_states"] + config.output_hidden_states = True + + check_hidden_states_output(inputs_dict, config, model_class) + + def test_jit_compilation(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + with self.subTest(model_class.__name__): + prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class) + model = model_class(config) + + @jax.jit + def model_jitted(pixel_values, **kwargs): + return model(pixel_values=pixel_values, **kwargs) + + with self.subTest("JIT Enabled"): + jitted_outputs = model_jitted(**prepared_inputs_dict).to_tuple() + + with self.subTest("JIT Disabled"): + with jax.disable_jit(): + outputs = model_jitted(**prepared_inputs_dict).to_tuple() + + self.assertEqual(len(outputs), len(jitted_outputs)) + for jitted_output, output in zip(jitted_outputs, outputs): + self.assertEqual(jitted_output.shape, output.shape) + + +# We will verify our results on an image of cute cats +def prepare_img(): + image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png") + return image + + +@require_flax +class FlaxRegNetModelIntegrationTest(unittest.TestCase): + @cached_property + def default_feature_extractor(self): + return AutoFeatureExtractor.from_pretrained("facebook/regnet-y-040") if is_vision_available() else None + + @slow + def test_inference_image_classification_head(self): + model = FlaxRegNetForImageClassification.from_pretrained("Shubhamai/regnet-y-040") + + feature_extractor = self.default_feature_extractor + image = prepare_img() + inputs = feature_extractor(images=image, return_tensors="np") + + outputs = model(**inputs) + + # verify the logits + expected_shape = (1, 1000) + self.assertEqual(outputs.logits.shape, expected_shape) + + expected_slice = jnp.array([-0.4180, -1.5051, -3.4836]) + + self.assertTrue(jnp.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4)) From 331322863c1c6fb753f953d00df16a5b36f48252 Mon Sep 17 00:00:00 2001 From: Shubhamai Date: Sat, 25 Mar 2023 10:45:42 +0530 Subject: [PATCH 2/4] review changes --- .../models/regnet/modeling_flax_regnet.py | 104 +++++++++--------- .../models/resnet/modeling_flax_resnet.py | 2 +- .../regnet/test_modeling_flax_regnet.py | 9 +- 3 files changed, 56 insertions(+), 59 deletions(-) diff --git a/src/transformers/models/regnet/modeling_flax_regnet.py b/src/transformers/models/regnet/modeling_flax_regnet.py index dffdc410fbcc..9fef1868d60a 100644 --- a/src/transformers/models/regnet/modeling_flax_regnet.py +++ b/src/transformers/models/regnet/modeling_flax_regnet.py @@ -90,11 +90,12 @@ """ +# Copied from transformers.models.resnet.modeling_flax_resnet.Identity class Identity(nn.Module): """Identity function.""" @nn.compact - def __call__(self, x, deterministic=None): + def __call__(self, x, **kwargs): return x @@ -150,6 +151,7 @@ def __call__(self, pixel_values: jnp.ndarray, deterministic: bool = True) -> jnp return hidden_state +# Copied from transformers.models.resnet.modeling_flax_resnet.FlaxResNetShortCut with ResNet->RegNet class FlaxRegNetShortCut(nn.Module): """ RegNet shortcut, used to project the residual features to the correct size. If needed, it is also used to @@ -171,8 +173,8 @@ def setup(self): ) self.normalization = nn.BatchNorm(momentum=0.9, epsilon=1e-05, dtype=self.dtype) - def __call__(self, input: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray: - hidden_state = self.convolution(input) + def __call__(self, x: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray: + hidden_state = self.convolution(x) hidden_state = self.normalization(hidden_state, use_running_average=deterministic) return hidden_state @@ -199,8 +201,6 @@ def setup(self): ) # 2 is the name used in corresponding pytorch implementation def __call__(self, hidden_state: jnp.ndarray) -> jnp.ndarray: - # b h w c -> b 1 1 c - hidden_state = self.conv_1(hidden_state) hidden_state = nn.relu(hidden_state) hidden_state = self.conv_2(hidden_state) @@ -223,7 +223,6 @@ def setup(self): self.attention = FlaxRegNetSELayerCollection(self.in_channels, self.reduced_channels, dtype=self.dtype) def __call__(self, hidden_state: jnp.ndarray) -> jnp.ndarray: - # b h w c -> b 1 1 c pooled = self.pooler( hidden_state, window_shape=(hidden_state.shape[1], hidden_state.shape[2]), @@ -355,7 +354,7 @@ def setup(self): ), ] - def __call__(self, hidden_state: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray: + def __call__(self, hidden_state: jnp.ndarray) -> jnp.ndarray: for layer in self.layer: hidden_state = layer(hidden_state) return hidden_state @@ -417,7 +416,7 @@ class FlaxRegNetStageLayersCollection(nn.Module): def setup(self): layer = FlaxRegNetXLayer if self.config.layer_type == "x" else FlaxRegNetYLayer - self.layers = [ + layers = [ # downsampling is done in the first layer with stride of 2 layer( self.config, @@ -426,8 +425,11 @@ def setup(self): stride=self.stride, dtype=self.dtype, name="0", - ), - *[ + ) + ] + + for i in range(self.depth - 1): + layers.append( layer( self.config, self.out_channels, @@ -435,9 +437,9 @@ def setup(self): dtype=self.dtype, name=str(i + 1), ) - for i in range(self.depth - 1) - ], - ] + ) + + self.layers = layers def __call__(self, x: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray: hidden_state = x @@ -446,6 +448,7 @@ def __call__(self, x: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray: return hidden_state +# Copied from transformers.models.resnet.modeling_flax_resnet.FlaxResNetStage with ResNet->RegNet class FlaxRegNetStage(nn.Module): """ A RegNet stage composed by stacked layers. @@ -468,18 +471,18 @@ def setup(self): dtype=self.dtype, ) - def __call__(self, hidden_state: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray: - hidden_state = self.layers(hidden_state, deterministic=deterministic) - return hidden_state + def __call__(self, x: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray: + return self.layers(x, deterministic=deterministic) +# Copied from transformers.models.resnet.modeling_flax_resnet.FlaxResNetStageCollection with ResNet->RegNet class FlaxRegNetStageCollection(nn.Module): config: RegNetConfig dtype: jnp.dtype = jnp.float32 def setup(self): in_out_channels = zip(self.config.hidden_sizes, self.config.hidden_sizes[1:]) - self.stages = [ + stages = [ FlaxRegNetStage( self.config, self.config.embedding_size, @@ -488,20 +491,16 @@ def setup(self): depth=self.config.depths[0], dtype=self.dtype, name="0", - ), - *[ - FlaxRegNetStage( - self.config, - in_channels, - out_channels, - depth=depth, - dtype=self.dtype, - name=str(i + 1), - ) - for i, ((in_channels, out_channels), depth) in enumerate(zip(in_out_channels, self.config.depths[1:])) - ], + ) ] + for i, ((in_channels, out_channels), depth) in enumerate(zip(in_out_channels, self.config.depths[1:])): + stages.append( + FlaxRegNetStage(self.config, in_channels, out_channels, depth=depth, dtype=self.dtype, name=str(i + 1)) + ) + + self.stages = stages + def __call__( self, hidden_state: jnp.ndarray, @@ -519,6 +518,7 @@ def __call__( return hidden_state, hidden_states +# Copied from transformers.models.resnet.modeling_flax_resnet.FlaxResNetEncoder with ResNet->RegNet class FlaxRegNetEncoder(nn.Module): config: RegNetConfig dtype: jnp.dtype = jnp.float32 @@ -532,9 +532,7 @@ def __call__( output_hidden_states: bool = False, return_dict: bool = True, deterministic: bool = True, - ) -> FlaxBaseModelOutputWithPooling: - hidden_states = () if output_hidden_states else None - + ) -> FlaxBaseModelOutputWithNoAttention: hidden_state, hidden_states = self.stages( hidden_state, output_hidden_states=output_hidden_states, deterministic=deterministic ) @@ -545,9 +543,13 @@ def __call__( if not return_dict: return tuple(v for v in [hidden_state, hidden_states] if v is not None) - return FlaxBaseModelOutputWithNoAttention(last_hidden_state=hidden_state, hidden_states=hidden_states) + return FlaxBaseModelOutputWithNoAttention( + last_hidden_state=hidden_state, + hidden_states=hidden_states, + ) +# Copied from transformers.models.resnet.modeling_flax_resnet.FlaxResNetPreTrainedModel with ResNet->RegNet,resnet->regnet,RESNET->REGNET class FlaxRegNetPreTrainedModel(FlaxPreTrainedModel): """ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained @@ -569,14 +571,9 @@ def __init__( **kwargs, ): module = self.module_class(config=config, dtype=dtype, **kwargs) - super().__init__( - config, - module, - input_shape=input_shape, - seed=seed, - dtype=dtype, - _do_init=_do_init, - ) + if input_shape is None: + input_shape = (1, config.image_size, config.image_size, config.num_channels) + super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: # init input tensors @@ -625,10 +622,11 @@ def __call__( output_hidden_states, return_dict, rngs=rngs, - mutable=["batch_stats"] if train else False, + mutable=["batch_stats"] if train else False, # Returing tuple with batch_stats only when train is True ) +# Copied from transformers.models.resnet.modeling_flax_resnet.FlaxResNetModule with ResNet->RegNet class FlaxRegNetModule(nn.Module): config: RegNetConfig dtype: jnp.dtype = jnp.float32 # the dtype of the computation @@ -636,15 +634,19 @@ class FlaxRegNetModule(nn.Module): def setup(self): self.embedder = FlaxRegNetEmbeddings(self.config, dtype=self.dtype) self.encoder = FlaxRegNetEncoder(self.config, dtype=self.dtype) - self.pooler = partial(nn.avg_pool, padding=((0, 0), (0, 0))) - @add_start_docstrings_to_model_forward(REGNET_INPUTS_DOCSTRING) + # Adaptive average pooling used in resnet + self.pooler = partial( + nn.avg_pool, + padding=((0, 0), (0, 0)), + ) + def __call__( self, - pixel_values: jnp.ndarray, + pixel_values, deterministic: bool = True, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, + output_hidden_states: bool = False, + return_dict: bool = True, ) -> FlaxBaseModelOutputWithPoolingAndNoAttention: output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states @@ -718,6 +720,7 @@ class FlaxRegNetModel(FlaxRegNetPreTrainedModel): ) +# Copied from transformers.models.resnet.modeling_flax_resnet.FlaxResNetClassifierCollection with ResNet->RegNet class FlaxRegNetClassifierCollection(nn.Module): config: RegNetConfig dtype: jnp.dtype = jnp.float32 @@ -729,6 +732,7 @@ def __call__(self, x: jnp.ndarray) -> jnp.ndarray: return self.classifier(x) +# Copied from transformers.models.resnet.modeling_flax_resnet.FlaxResNetForImageClassificationModule with ResNet->RegNet,resnet->regnet,RESNET->REGNET class FlaxRegNetForImageClassificationModule(nn.Module): config: RegNetConfig dtype: jnp.dtype = jnp.float32 @@ -737,14 +741,10 @@ def setup(self): self.regnet = FlaxRegNetModule(config=self.config, dtype=self.dtype) if self.config.num_labels > 0: - self.classifier = FlaxRegNetClassifierCollection( - self.config, - dtype=self.dtype, - ) + self.classifier = FlaxRegNetClassifierCollection(self.config, dtype=self.dtype) else: self.classifier = Identity() - @add_start_docstrings_to_model_forward(REGNET_INPUTS_DOCSTRING) def __call__( self, pixel_values=None, diff --git a/src/transformers/models/resnet/modeling_flax_resnet.py b/src/transformers/models/resnet/modeling_flax_resnet.py index 36b286960743..875716d3f5be 100644 --- a/src/transformers/models/resnet/modeling_flax_resnet.py +++ b/src/transformers/models/resnet/modeling_flax_resnet.py @@ -89,7 +89,7 @@ class Identity(nn.Module): """Identity function.""" @nn.compact - def __call__(self, x): + def __call__(self, x, **kwargs): return x diff --git a/tests/models/regnet/test_modeling_flax_regnet.py b/tests/models/regnet/test_modeling_flax_regnet.py index 4f3ffae84d08..226e64737a07 100644 --- a/tests/models/regnet/test_modeling_flax_regnet.py +++ b/tests/models/regnet/test_modeling_flax_regnet.py @@ -40,8 +40,8 @@ class FlaxRegNetModelTester(unittest.TestCase): def __init__( self, parent, - batch_size=1, - image_size=224, + batch_size=3, + image_size=32, num_channels=3, embeddings_size=10, hidden_sizes=[10, 20, 30, 40], @@ -102,10 +102,7 @@ def create_and_check_for_image_classification(self, config, pixel_values): def prepare_config_and_inputs_for_common(self): config_and_inputs = self.prepare_config_and_inputs() - ( - config, - pixel_values, - ) = config_and_inputs + config, pixel_values = config_and_inputs inputs_dict = {"pixel_values": pixel_values} return config, inputs_dict From d9726740106c69c6ab6b280581759da18ad19485 Mon Sep 17 00:00:00 2001 From: Shubhamai Date: Sun, 26 Mar 2023 18:21:06 +0530 Subject: [PATCH 3/4] post model PR merge --- tests/models/regnet/test_modeling_flax_regnet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/regnet/test_modeling_flax_regnet.py b/tests/models/regnet/test_modeling_flax_regnet.py index 226e64737a07..e9788ab09d7e 100644 --- a/tests/models/regnet/test_modeling_flax_regnet.py +++ b/tests/models/regnet/test_modeling_flax_regnet.py @@ -220,7 +220,7 @@ def default_feature_extractor(self): @slow def test_inference_image_classification_head(self): - model = FlaxRegNetForImageClassification.from_pretrained("Shubhamai/regnet-y-040") + model = FlaxRegNetForImageClassification.from_pretrained("facebook/regnet-y-040") feature_extractor = self.default_feature_extractor image = prepare_img() From a43f26c1390aa6ce3f71d0da55a65fd3488785ae Mon Sep 17 00:00:00 2001 From: Shubhamai Date: Sun, 26 Mar 2023 18:35:10 +0530 Subject: [PATCH 4/4] updating doc --- docs/source/es/index.mdx | 2 +- docs/source/it/index.mdx | 2 +- docs/source/pt/index.mdx | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/source/es/index.mdx b/docs/source/es/index.mdx index 79186ee7b56c..49a4f83053cd 100644 --- a/docs/source/es/index.mdx +++ b/docs/source/es/index.mdx @@ -235,7 +235,7 @@ Flax), PyTorch y/o TensorFlow. | RAG | ✅ | ❌ | ✅ | ✅ | ❌ | | Realm | ✅ | ✅ | ✅ | ❌ | ❌ | | Reformer | ✅ | ✅ | ✅ | ❌ | ❌ | -| RegNet | ❌ | ❌ | ✅ | ❌ | ✅ | +| RegNet | ❌ | ❌ | ✅ | ✅ | ✅ | | RemBERT | ✅ | ✅ | ✅ | ✅ | ❌ | | ResNet | ❌ | ❌ | ✅ | ❌ | ✅ | | RetriBERT | ✅ | ✅ | ✅ | ❌ | ❌ | diff --git a/docs/source/it/index.mdx b/docs/source/it/index.mdx index 35dea0ef2452..4c050bfe5224 100644 --- a/docs/source/it/index.mdx +++ b/docs/source/it/index.mdx @@ -252,7 +252,7 @@ tokenizer (chiamato "slow"). Un tokenizer "fast" supportato dalla libreria 🤗 | RAG | ✅ | ❌ | ✅ | ✅ | ❌ | | Realm | ✅ | ✅ | ✅ | ❌ | ❌ | | Reformer | ✅ | ✅ | ✅ | ❌ | ❌ | -| RegNet | ❌ | ❌ | ✅ | ❌ | ✅ | +| RegNet | ❌ | ❌ | ✅ | ✅ | ✅ | | RemBERT | ✅ | ✅ | ✅ | ✅ | ❌ | | ResNet | ❌ | ❌ | ✅ | ✅ | ✅ | | RetriBERT | ✅ | ✅ | ✅ | ❌ | ❌ | diff --git a/docs/source/pt/index.mdx b/docs/source/pt/index.mdx index e61d4f7780b7..9b5cbc12e610 100644 --- a/docs/source/pt/index.mdx +++ b/docs/source/pt/index.mdx @@ -250,7 +250,7 @@ disso, são diferenciados pelo suporte em diferentes frameworks: JAX (por meio d | RAG | ✅ | ❌ | ✅ | ✅ | ❌ | | Realm | ✅ | ✅ | ✅ | ❌ | ❌ | | Reformer | ✅ | ✅ | ✅ | ❌ | ❌ | -| RegNet | ❌ | ❌ | ✅ | ❌ | ✅ | +| RegNet | ❌ | ❌ | ✅ | ✅ | ✅ | | RemBERT | ✅ | ✅ | ✅ | ✅ | ❌ | | ResNet | ❌ | ❌ | ✅ | ❌ | ✅ | | RetriBERT | ✅ | ✅ | ✅ | ❌ | ❌ |