From 3f53abd132c91fd38bdaf67e08dbf7fd7d2c15a8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Marcelo=20Tozato?= Date: Thu, 22 Aug 2024 14:32:17 +0200 Subject: [PATCH 01/44] first draft --- docs/source/en/_toctree.yml | 2 + docs/source/en/index.md | 1 + docs/source/en/model_doc/ijepa.md | 51 + src/transformers/__init__.py | 41 + src/transformers/models/__init__.py | 1 + .../models/auto/configuration_auto.py | 2 + .../models/auto/feature_extraction_auto.py | 1 + src/transformers/models/auto/modeling_auto.py | 2 + src/transformers/models/ijepa/__init__.py | 76 ++ .../models/ijepa/configuration_ijepa.py | 140 +++ .../models/ijepa/convert_dino_to_pytorch.py | 298 ++++++ .../ijepa/convert_ijepa_timm_to_pytorch.py | 347 +++++++ .../models/ijepa/modeling_ijepa.py | 932 ++++++++++++++++++ src/transformers/utils/dummy_flax_objects.py | 21 + src/transformers/utils/dummy_pt_objects.py | 25 +- src/transformers/utils/dummy_tf_objects.py | 21 + .../utils/dummy_torchvision_objects.py | 7 + .../utils/dummy_vision_objects.py | 7 + tests/models/ijepa/__init__.py | 0 tests/models/ijepa/test_modeling_ijepa.py | 373 +++++++ 20 files changed, 2347 insertions(+), 1 deletion(-) create mode 100644 docs/source/en/model_doc/ijepa.md create mode 100644 src/transformers/models/ijepa/__init__.py create mode 100644 src/transformers/models/ijepa/configuration_ijepa.py create mode 100644 src/transformers/models/ijepa/convert_dino_to_pytorch.py create mode 100644 src/transformers/models/ijepa/convert_ijepa_timm_to_pytorch.py create mode 100644 src/transformers/models/ijepa/modeling_ijepa.py create mode 100644 tests/models/ijepa/__init__.py create mode 100644 tests/models/ijepa/test_modeling_ijepa.py diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index dc88bbd45ab2..805c616507ba 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -621,6 +621,8 @@ title: GLPN - local: model_doc/hiera title: Hiera + - local: model_doc/ijepa + title: I-JEPA - local: model_doc/imagegpt title: ImageGPT - local: model_doc/levit diff --git a/docs/source/en/index.md b/docs/source/en/index.md index ac73d9ab70fc..963f88410b4f 100644 --- a/docs/source/en/index.md +++ b/docs/source/en/index.md @@ -165,6 +165,7 @@ Flax), PyTorch, and/or TensorFlow. | [Hiera](model_doc/hiera) | ✅ | ❌ | ❌ | | [Hubert](model_doc/hubert) | ✅ | ✅ | ❌ | | [I-BERT](model_doc/ibert) | ✅ | ❌ | ❌ | +| [I-JEPA](model_doc/ijepa) | ✅ | ❌ | ❌ | | [IDEFICS](model_doc/idefics) | ✅ | ✅ | ❌ | | [Idefics2](model_doc/idefics2) | ✅ | ❌ | ❌ | | [ImageGPT](model_doc/imagegpt) | ✅ | ❌ | ❌ | diff --git a/docs/source/en/model_doc/ijepa.md b/docs/source/en/model_doc/ijepa.md new file mode 100644 index 000000000000..2ce748fb39f9 --- /dev/null +++ b/docs/source/en/model_doc/ijepa.md @@ -0,0 +1,51 @@ + + +# I-JEPA + +## Overview + +The I-JEPA model was proposed in []() by . + + +The abstract from the paper is the following: + +** + +Tips: + + + +This model was contributed by [INSERT YOUR HF USERNAME HERE](https://huggingface.co/). +The original code can be found [here](). + + +## IJepaConfig + +[[autodoc]] IJepaConfig + +## IJepaModel + +[[autodoc]] IJepaModel + - forward + +## IJepaForImageClassification + +[[autodoc]] IJepaForImageClassification + - forward + + + diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 1d36e7f8c746..39cefb38022b 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -476,6 +476,7 @@ "models.ibert": ["IBertConfig"], "models.idefics": ["IdeficsConfig"], "models.idefics2": ["Idefics2Config"], + "models.ijepa": ["IJepaConfig"], "models.imagegpt": ["ImageGPTConfig"], "models.informer": ["InformerConfig"], "models.instructblip": [ @@ -1217,6 +1218,7 @@ ] else: _import_structure["image_processing_utils_fast"] = ["BaseImageProcessorFast"] + _import_structure["models.ijepa"].append("IJEPAImageProcessorFast") _import_structure["models.vit"].append("ViTImageProcessorFast") # PyTorch-backed objects @@ -2377,6 +2379,13 @@ "Idefics2Processor", ] ) + _import_structure["models.ijepa"].extend( + [ + "IJepaForImageClassification", + "IJepaModel", + "IJepaPreTrainedModel", + ] + ) _import_structure["models.imagegpt"].extend( [ "ImageGPTForCausalImageModeling", @@ -4079,6 +4088,13 @@ ] ) + _import_structure["models.ijepa"].extend( + [ + "TFIJepaForImageClassification", + "TFIJepaModel", + "TFIJepaPreTrainedModel", + ] + ) _import_structure["models.layoutlm"].extend( [ "TFLayoutLMForMaskedLM", @@ -4739,6 +4755,13 @@ _import_structure["models.vision_encoder_decoder"].append("FlaxVisionEncoderDecoderModel") _import_structure["models.vision_text_dual_encoder"].extend(["FlaxVisionTextDualEncoderModel"]) _import_structure["models.vit"].extend(["FlaxViTForImageClassification", "FlaxViTModel", "FlaxViTPreTrainedModel"]) + _import_structure["models.ijepa"].extend( + [ + "FlaxIJepaForImageClassification", + "FlaxIJepaModel", + "FlaxIJepaPreTrainedModel", + ] + ) _import_structure["models.wav2vec2"].extend( [ "FlaxWav2Vec2ForCTC", @@ -5219,6 +5242,7 @@ IdeficsConfig, ) from .models.idefics2 import Idefics2Config + from .models.ijepa import IJepaConfig from .models.imagegpt import ImageGPTConfig from .models.informer import InformerConfig from .models.instructblip import ( @@ -5941,6 +5965,7 @@ from .models.grounding_dino import GroundingDinoImageProcessor from .models.idefics import IdeficsImageProcessor from .models.idefics2 import Idefics2ImageProcessor + from .models.ijepa import IJepaImageProcessor from .models.imagegpt import ImageGPTFeatureExtractor, ImageGPTImageProcessor from .models.instructblipvideo import InstructBlipVideoImageProcessor from .models.layoutlmv2 import ( @@ -6003,6 +6028,7 @@ from .utils.dummy_torchvision_objects import * else: from .image_processing_utils_fast import BaseImageProcessorFast + from .models.ijepa import IJepaImageProcessorFast from .models.vit import ViTImageProcessorFast # Modeling @@ -6962,6 +6988,11 @@ Idefics2PreTrainedModel, Idefics2Processor, ) + from .models.ijepa import ( + IJepaForImageClassification, + IJepaModel, + IJepaPreTrainedModel, + ) from .models.imagegpt import ( ImageGPTForCausalImageModeling, ImageGPTForImageClassification, @@ -8307,6 +8338,11 @@ TFIdeficsModel, TFIdeficsPreTrainedModel, ) + from .models.ijepa import ( + TFIJepaForImageClassification, + TFIJepaModel, + TFIJepaPreTrainedModel, + ) from .models.layoutlm import ( TFLayoutLMForMaskedLM, TFLayoutLMForQuestionAnswering, @@ -8774,6 +8810,11 @@ FlaxGPTJModel, FlaxGPTJPreTrainedModel, ) + from .models.ijepa import ( + FlaxIJepaForImageClassification, + FlaxIJepaModel, + FlaxIJepaPreTrainedModel, + ) from .models.llama import ( FlaxLlamaForCausalLM, FlaxLlamaModel, diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index f60f72a23614..421ea4716d38 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -113,6 +113,7 @@ ibert, idefics, idefics2, + ijepa, imagegpt, informer, instructblip, diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index ecd0a6674041..f337fc8bb6ef 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -130,6 +130,7 @@ ("ibert", "IBertConfig"), ("idefics", "IdeficsConfig"), ("idefics2", "Idefics2Config"), + ("ijepa", "IJepaConfig"), ("imagegpt", "ImageGPTConfig"), ("informer", "InformerConfig"), ("instructblip", "InstructBlipConfig"), @@ -419,6 +420,7 @@ ("ibert", "I-BERT"), ("idefics", "IDEFICS"), ("idefics2", "Idefics2"), + ("ijepa", "I-JEPA"), ("imagegpt", "ImageGPT"), ("informer", "Informer"), ("instructblip", "InstructBLIP"), diff --git a/src/transformers/models/auto/feature_extraction_auto.py b/src/transformers/models/auto/feature_extraction_auto.py index 7f335d66584f..bba05458851d 100644 --- a/src/transformers/models/auto/feature_extraction_auto.py +++ b/src/transformers/models/auto/feature_extraction_auto.py @@ -63,6 +63,7 @@ ("glpn", "GLPNFeatureExtractor"), ("groupvit", "CLIPFeatureExtractor"), ("hubert", "Wav2Vec2FeatureExtractor"), + ("ijepa", "IJepaFeatureExtractor"), ("imagegpt", "ImageGPTFeatureExtractor"), ("layoutlmv2", "LayoutLMv2FeatureExtractor"), ("layoutlmv3", "LayoutLMv3FeatureExtractor"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 38086aa0f2e9..4ff1e87a1dc1 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -127,6 +127,7 @@ ("ibert", "IBertModel"), ("idefics", "IdeficsModel"), ("idefics2", "Idefics2Model"), + ("ijepa", "IJepaModel"), ("imagegpt", "ImageGPTModel"), ("informer", "InformerModel"), ("jamba", "JambaModel"), @@ -552,6 +553,7 @@ ("focalnet", "FocalNetModel"), ("glpn", "GLPNModel"), ("hiera", "HieraModel"), + ("ijepa", "IJepaModel"), ("imagegpt", "ImageGPTModel"), ("levit", "LevitModel"), ("mobilenet_v1", "MobileNetV1Model"), diff --git a/src/transformers/models/ijepa/__init__.py b/src/transformers/models/ijepa/__init__.py new file mode 100644 index 000000000000..685ef0e255c5 --- /dev/null +++ b/src/transformers/models/ijepa/__init__.py @@ -0,0 +1,76 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import ( + OptionalDependencyNotAvailable, + _LazyModule, + is_torch_available, + is_torchvision_available, +) + +_import_structure = {"configuration_ijepa": ["IJepaConfig", "IJepaOnnxConfig"]} + +try: + if not is_torchvision_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["image_processing_ijepa_fast"] = [ + "IJepaImageProcessorFast" + ] + +try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["modeling_ijepa"] = [ + "IJepaForImageClassification", + "IJepaModel", + "IJepaPreTrainedModel", + ] + +if TYPE_CHECKING: + from .configuration_ijepa import IJepaConfig, IJepaOnnxConfig + + try: + if not is_torchvision_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .modeling_ijepa import ( + IJepaForImageClassification, + IJepaModel, + IJepaPreTrainedModel, + ) + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) diff --git a/src/transformers/models/ijepa/configuration_ijepa.py b/src/transformers/models/ijepa/configuration_ijepa.py new file mode 100644 index 000000000000..09481db69348 --- /dev/null +++ b/src/transformers/models/ijepa/configuration_ijepa.py @@ -0,0 +1,140 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""I-JEPA model configuration""" + +from collections import OrderedDict +from typing import Mapping + +from packaging import version + +from ...configuration_utils import PretrainedConfig +from ...onnx import OnnxConfig +from ...utils import logging + +logger = logging.get_logger(__name__) + + +class IJepaConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`IJepaModel`]. It is used to instantiate an IJEPA + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the I-JEPA + [google/ijepa-base-patch16-224](https://huggingface.co/google/ijepa-base-patch16-224) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + intermediate_size (`int`, *optional*, defaults to 3072): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` are supported. + hidden_dropout_prob (`float`, *optional*, defaults to 0.0): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layer_norm_eps (`float`, *optional*, defaults to 1e-12): + The epsilon used by the layer normalization layers. + image_size (`int`, *optional*, defaults to 224): + The size (resolution) of each image. + patch_size (`int`, *optional*, defaults to 16): + The size (resolution) of each patch. + num_channels (`int`, *optional*, defaults to 3): + The number of input channels. + qkv_bias (`bool`, *optional*, defaults to `True`): + Whether to add a bias to the queries, keys and values. + encoder_stride (`int`, *optional*, defaults to 16): + Factor to increase the spatial resolution by in the decoder head for masked image modeling. + + Example: + + ```python + >>> from transformers import IJepaConfig, IJepaModel + + >>> # Initializing a IJEPA ijepa-base-patch16-224 style configuration + >>> configuration = IJepaConfig() + + >>> # Initializing a model (with random weights) from the ijepa-base-patch16-224 style configuration + >>> model = IJepaModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "ijepa" + + def __init__( + self, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + hidden_act="gelu", + hidden_dropout_prob=0.0, + attention_probs_dropout_prob=0.0, + initializer_range=0.02, + layer_norm_eps=1e-12, + image_size=224, + patch_size=16, + num_channels=3, + qkv_bias=True, + encoder_stride=16, + **kwargs, + ): + super().__init__(**kwargs) + + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.qkv_bias = qkv_bias + self.encoder_stride = encoder_stride + + +class IJepaOnnxConfig(OnnxConfig): + torch_onnx_minimum_version = version.parse("1.11") + + @property + def inputs(self) -> Mapping[str, Mapping[int, str]]: + return OrderedDict( + [ + ( + "pixel_values", + {0: "batch", 1: "num_channels", 2: "height", 3: "width"}, + ), + ] + ) + + @property + def atol_for_validation(self) -> float: + return 1e-4 diff --git a/src/transformers/models/ijepa/convert_dino_to_pytorch.py b/src/transformers/models/ijepa/convert_dino_to_pytorch.py new file mode 100644 index 000000000000..351a4d11e179 --- /dev/null +++ b/src/transformers/models/ijepa/convert_dino_to_pytorch.py @@ -0,0 +1,298 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert IJEPA checkpoints trained with the DINO method.""" + +import argparse +import json +from pathlib import Path + +import requests +import torch +from huggingface_hub import hf_hub_download +from PIL import Image + +from transformers import ( + IJepaConfig, + IJepaForImageClassification, + IJepaImageProcessor, + IJepaModel, +) +from transformers.utils import logging + +logging.set_verbosity_info() +logger = logging.get_logger(__name__) + + +# here we list all keys to be renamed (original name on the left, our name on the right) +def create_rename_keys(config, base_model=False): + rename_keys = [] + for i in range(config.num_hidden_layers): + # encoder layers: output projection, 2 feedforward neural networks and 2 layernorms + rename_keys.append( + ( + f"blocks.{i}.norm1.weight", + f"ijepa.encoder.layer.{i}.layernorm_before.weight", + ) + ) + rename_keys.append( + ( + f"blocks.{i}.norm1.bias", + f"ijepa.encoder.layer.{i}.layernorm_before.bias", + ) + ) + rename_keys.append( + ( + f"blocks.{i}.attn.proj.weight", + f"ijepa.encoder.layer.{i}.attention.output.dense.weight", + ) + ) + rename_keys.append( + ( + f"blocks.{i}.attn.proj.bias", + f"ijepa.encoder.layer.{i}.attention.output.dense.bias", + ) + ) + rename_keys.append( + ( + f"blocks.{i}.norm2.weight", + f"ijepa.encoder.layer.{i}.layernorm_after.weight", + ) + ) + rename_keys.append( + ( + f"blocks.{i}.norm2.bias", + f"ijepa.encoder.layer.{i}.layernorm_after.bias", + ) + ) + rename_keys.append( + ( + f"blocks.{i}.mlp.fc1.weight", + f"ijepa.encoder.layer.{i}.intermediate.dense.weight", + ) + ) + rename_keys.append( + ( + f"blocks.{i}.mlp.fc1.bias", + f"ijepa.encoder.layer.{i}.intermediate.dense.bias", + ) + ) + rename_keys.append( + ( + f"blocks.{i}.mlp.fc2.weight", + f"ijepa.encoder.layer.{i}.output.dense.weight", + ) + ) + rename_keys.append( + ( + f"blocks.{i}.mlp.fc2.bias", + f"ijepa.encoder.layer.{i}.output.dense.bias", + ) + ) + + # projection layer + position embeddings + rename_keys.extend( + [ + ("cls_token", "ijepa.embeddings.cls_token"), + ( + "patch_embed.proj.weight", + "ijepa.embeddings.patch_embeddings.projection.weight", + ), + ( + "patch_embed.proj.bias", + "ijepa.embeddings.patch_embeddings.projection.bias", + ), + ("pos_embed", "ijepa.embeddings.position_embeddings"), + ] + ) + + if base_model: + # layernorm + pooler + rename_keys.extend( + [ + ("norm.weight", "layernorm.weight"), + ("norm.bias", "layernorm.bias"), + ] + ) + + # if just the base model, we should remove "ijepa" from all keys that start with "ijepa" + rename_keys = [ + (pair[0], pair[1][4:]) if pair[1].startswith("ijepa") else pair + for pair in rename_keys + ] + else: + # layernorm + classification head + rename_keys.extend( + [ + ("norm.weight", "ijepa.layernorm.weight"), + ("norm.bias", "ijepa.layernorm.bias"), + ("head.weight", "classifier.weight"), + ("head.bias", "classifier.bias"), + ] + ) + + return rename_keys + + +# we split up the matrix of each encoder layer into queries, keys and values +def read_in_q_k_v(state_dict, config, base_model=False): + for i in range(config.num_hidden_layers): + if base_model: + prefix = "" + else: + prefix = "ijepa." + # read in weights + bias of input projection layer (in timm, this is a single matrix + bias) + in_proj_weight = state_dict.pop(f"blocks.{i}.attn.qkv.weight") + in_proj_bias = state_dict.pop(f"blocks.{i}.attn.qkv.bias") + # next, add query, keys and values (in that order) to the state dict + state_dict[ + f"{prefix}encoder.layer.{i}.attention.attention.query.weight" + ] = in_proj_weight[: config.hidden_size, :] + state_dict[ + f"{prefix}encoder.layer.{i}.attention.attention.query.bias" + ] = in_proj_bias[: config.hidden_size] + state_dict[ + f"{prefix}encoder.layer.{i}.attention.attention.key.weight" + ] = in_proj_weight[config.hidden_size : config.hidden_size * 2, :] + state_dict[ + f"{prefix}encoder.layer.{i}.attention.attention.key.bias" + ] = in_proj_bias[config.hidden_size : config.hidden_size * 2] + state_dict[ + f"{prefix}encoder.layer.{i}.attention.attention.value.weight" + ] = in_proj_weight[-config.hidden_size :, :] + state_dict[ + f"{prefix}encoder.layer.{i}.attention.attention.value.bias" + ] = in_proj_bias[-config.hidden_size :] + + +def remove_classification_head_(state_dict): + ignore_keys = ["head.weight", "head.bias"] + for k in ignore_keys: + state_dict.pop(k, None) + + +def rename_key(dct, old, new): + val = dct.pop(old) + dct[new] = val + + +# We will verify our results on an image of cute cats +def prepare_img(): + url = "http://images.cocodataset.org/val2017/000000039769.jpg" + im = Image.open(requests.get(url, stream=True).raw) + return im + + +@torch.no_grad() +def convert_ijepa_checkpoint( + model_name, pytorch_dump_folder_path, base_model=True +): + """ + Copy/paste/tweak model's weights to our IJEPA structure. + """ + + # define default IJEPA configuration + config = IJepaConfig() + # patch_size + if model_name[-1] == "8": + config.patch_size = 8 + # set labels if required + if not base_model: + config.num_labels = 1000 + repo_id = "huggingface/label-files" + filename = "imagenet-1k-id2label.json" + id2label = json.load( + open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r") + ) + id2label = {int(k): v for k, v in id2label.items()} + config.id2label = id2label + config.label2id = {v: k for k, v in id2label.items()} + # size of the architecture + if model_name in ["dino_ijepas8", "dino_ijepas16"]: + config.hidden_size = 384 + config.intermediate_size = 1536 + config.num_hidden_layers = 12 + config.num_attention_heads = 6 + + # load original model from torch hub + original_model = torch.hub.load("facebookresearch/dino:main", model_name) + original_model.eval() + + # load state_dict of original model, remove and rename some keys + state_dict = original_model.state_dict() + if base_model: + remove_classification_head_(state_dict) + rename_keys = create_rename_keys(config, base_model=base_model) + for src, dest in rename_keys: + rename_key(state_dict, src, dest) + read_in_q_k_v(state_dict, config, base_model) + + # load HuggingFace model + if base_model: + model = IJepaModel(config, add_pooling_layer=False).eval() + else: + model = IJepaForImageClassification(config).eval() + model.load_state_dict(state_dict) + + # Check outputs on an image, prepared by IJepaImageProcessor + image_processor = IJepaImageProcessor() + encoding = image_processor(images=prepare_img(), return_tensors="pt") + pixel_values = encoding["pixel_values"] + outputs = model(pixel_values) + + if base_model: + final_hidden_state_cls_token = original_model(pixel_values) + assert torch.allclose( + final_hidden_state_cls_token, + outputs.last_hidden_state[:, 0, :], + atol=1e-1, + ) + else: + logits = original_model(pixel_values) + assert logits.shape == outputs.logits.shape + assert torch.allclose(logits, outputs.logits, atol=1e-3) + + Path(pytorch_dump_folder_path).mkdir(exist_ok=True) + print(f"Saving model {model_name} to {pytorch_dump_folder_path}") + model.save_pretrained(pytorch_dump_folder_path) + print(f"Saving image processor to {pytorch_dump_folder_path}") + image_processor.save_pretrained(pytorch_dump_folder_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument( + "--model_name", + default="dino_ijepab16", + type=str, + help="Name of the model trained with DINO you'd like to convert.", + ) + parser.add_argument( + "--pytorch_dump_folder_path", + default=None, + type=str, + help="Path to the output PyTorch model directory.", + ) + parser.add_argument( + "--base_model", + action="store_true", + help="Whether to only convert the base model (no projection head weights).", + ) + + parser.set_defaults(base_model=True) + args = parser.parse_args() + convert_ijepa_checkpoint( + args.model_name, args.pytorch_dump_folder_path, args.base_model + ) diff --git a/src/transformers/models/ijepa/convert_ijepa_timm_to_pytorch.py b/src/transformers/models/ijepa/convert_ijepa_timm_to_pytorch.py new file mode 100644 index 000000000000..607e4d58462a --- /dev/null +++ b/src/transformers/models/ijepa/convert_ijepa_timm_to_pytorch.py @@ -0,0 +1,347 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert I-JEPA and non-distilled DeiT checkpoints from the timm library.""" + +import argparse +from pathlib import Path + +import requests +import timm +import torch +from PIL import Image +from timm.data import ImageNetInfo, infer_imagenet_subset + +from transformers import ( + DeiTImageProcessor, + IJepaConfig, + IJepaForImageClassification, + IJepaImageProcessor, + IJepaModel, +) +from transformers.utils import logging + +logging.set_verbosity_info() +logger = logging.get_logger(__name__) + + +# here we list all keys to be renamed (original name on the left, our name on the right) +def create_rename_keys(config, base_model=False): + rename_keys = [] + for i in range(config.num_hidden_layers): + # encoder layers: output projection, 2 feedforward neural networks and 2 layernorms + rename_keys.append( + ( + f"blocks.{i}.norm1.weight", + f"ijepa.encoder.layer.{i}.layernorm_before.weight", + ) + ) + rename_keys.append( + ( + f"blocks.{i}.norm1.bias", + f"ijepa.encoder.layer.{i}.layernorm_before.bias", + ) + ) + rename_keys.append( + ( + f"blocks.{i}.attn.proj.weight", + f"ijepa.encoder.layer.{i}.attention.output.dense.weight", + ) + ) + rename_keys.append( + ( + f"blocks.{i}.attn.proj.bias", + f"ijepa.encoder.layer.{i}.attention.output.dense.bias", + ) + ) + rename_keys.append( + ( + f"blocks.{i}.norm2.weight", + f"ijepa.encoder.layer.{i}.layernorm_after.weight", + ) + ) + rename_keys.append( + ( + f"blocks.{i}.norm2.bias", + f"ijepa.encoder.layer.{i}.layernorm_after.bias", + ) + ) + rename_keys.append( + ( + f"blocks.{i}.mlp.fc1.weight", + f"ijepa.encoder.layer.{i}.intermediate.dense.weight", + ) + ) + rename_keys.append( + ( + f"blocks.{i}.mlp.fc1.bias", + f"ijepa.encoder.layer.{i}.intermediate.dense.bias", + ) + ) + rename_keys.append( + ( + f"blocks.{i}.mlp.fc2.weight", + f"ijepa.encoder.layer.{i}.output.dense.weight", + ) + ) + rename_keys.append( + ( + f"blocks.{i}.mlp.fc2.bias", + f"ijepa.encoder.layer.{i}.output.dense.bias", + ) + ) + + # projection layer + position embeddings + rename_keys.extend( + [ + ("cls_token", "ijepa.embeddings.cls_token"), + ( + "patch_embed.proj.weight", + "ijepa.embeddings.patch_embeddings.projection.weight", + ), + ( + "patch_embed.proj.bias", + "ijepa.embeddings.patch_embeddings.projection.bias", + ), + ("pos_embed", "ijepa.embeddings.position_embeddings"), + ] + ) + + if base_model: + # layernorm + rename_keys.extend( + [ + ("norm.weight", "layernorm.weight"), + ("norm.bias", "layernorm.bias"), + ] + ) + + # if just the base model, we should remove "ijepa" from all keys that start with "ijepa" + rename_keys = [ + (pair[0], pair[1][4:]) if pair[1].startswith("ijepa") else pair + for pair in rename_keys + ] + else: + # layernorm + classification head + rename_keys.extend( + [ + ("norm.weight", "ijepa.layernorm.weight"), + ("norm.bias", "ijepa.layernorm.bias"), + ("head.weight", "classifier.weight"), + ("head.bias", "classifier.bias"), + ] + ) + + return rename_keys + + +# we split up the matrix of each encoder layer into queries, keys and values +def read_in_q_k_v(state_dict, config, base_model=False): + for i in range(config.num_hidden_layers): + if base_model: + prefix = "" + else: + prefix = "ijepa." + # read in weights + bias of input projection layer (in timm, this is a single matrix + bias) + in_proj_weight = state_dict.pop(f"blocks.{i}.attn.qkv.weight") + in_proj_bias = state_dict.pop(f"blocks.{i}.attn.qkv.bias") + # next, add query, keys and values (in that order) to the state dict + state_dict[ + f"{prefix}encoder.layer.{i}.attention.attention.query.weight" + ] = in_proj_weight[: config.hidden_size, :] + state_dict[ + f"{prefix}encoder.layer.{i}.attention.attention.query.bias" + ] = in_proj_bias[: config.hidden_size] + state_dict[ + f"{prefix}encoder.layer.{i}.attention.attention.key.weight" + ] = in_proj_weight[config.hidden_size : config.hidden_size * 2, :] + state_dict[ + f"{prefix}encoder.layer.{i}.attention.attention.key.bias" + ] = in_proj_bias[config.hidden_size : config.hidden_size * 2] + state_dict[ + f"{prefix}encoder.layer.{i}.attention.attention.value.weight" + ] = in_proj_weight[-config.hidden_size :, :] + state_dict[ + f"{prefix}encoder.layer.{i}.attention.attention.value.bias" + ] = in_proj_bias[-config.hidden_size :] + + +def remove_classification_head_(state_dict): + ignore_keys = ["head.weight", "head.bias"] + for k in ignore_keys: + state_dict.pop(k, None) + + +def rename_key(dct, old, new): + val = dct.pop(old) + dct[new] = val + + +# We will verify our results on an image of cute cats +def prepare_img(): + url = "http://images.cocodataset.org/val2017/000000039769.jpg" + im = Image.open(requests.get(url, stream=True).raw) + return im + + +@torch.no_grad() +def convert_ijepa_checkpoint(ijepa_name, pytorch_dump_folder_path): + """ + Copy/paste/tweak model's weights to our IJEPA structure. + """ + + # define default IJEPA configuration + config = IJepaConfig() + base_model = False + + # load original model from timm + timm_model = timm.create_model(ijepa_name, pretrained=True) + timm_model.eval() + + # detect unsupported IJEPA models in transformers + # fc_norm is present + if not isinstance(getattr(timm_model, "fc_norm", None), torch.nn.Identity): + raise ValueError( + f"{ijepa_name} is not supported in transformers because of the presence of fc_norm." + ) + + # use of global average pooling in combination (or without) class token + if getattr(timm_model, "global_pool", None) == "avg": + raise ValueError( + f"{ijepa_name} is not supported in transformers because of use of global average pooling." + ) + + # CLIP style ijepa with norm_pre layer present + if "clip" in ijepa_name and not isinstance( + getattr(timm_model, "norm_pre", None), torch.nn.Identity + ): + raise ValueError( + f"{ijepa_name} is not supported in transformers because it's a CLIP style IJEPA with norm_pre layer." + ) + + # SigLIP style ijepa with attn_pool layer present + if ( + "siglip" in ijepa_name + and getattr(timm_model, "global_pool", None) == "map" + ): + raise ValueError( + f"{ijepa_name} is not supported in transformers because it's a SigLIP style IJEPA with attn_pool." + ) + + # use of layer scale in IJEPA model blocks + if not isinstance( + getattr(timm_model.blocks[0], "ls1", None), torch.nn.Identity + ) or not isinstance( + getattr(timm_model.blocks[0], "ls2", None), torch.nn.Identity + ): + raise ValueError( + f"{ijepa_name} is not supported in transformers because it uses a layer scale in its blocks." + ) + + # Hybrid ResNet-IJEPAs + if not isinstance(timm_model.patch_embed, timm.layers.PatchEmbed): + raise ValueError( + f"{ijepa_name} is not supported in transformers because it is a hybrid ResNet-IJEPA." + ) + + # get patch size and image size from the patch embedding submodule + config.patch_size = timm_model.patch_embed.patch_size[0] + config.image_size = timm_model.patch_embed.img_size[0] + + # retrieve architecture-specific parameters from the timm model + config.hidden_size = timm_model.embed_dim + config.intermediate_size = timm_model.blocks[0].mlp.fc1.out_features + config.num_hidden_layers = len(timm_model.blocks) + config.num_attention_heads = timm_model.blocks[0].attn.num_heads + + # check whether the model has a classification head or not + if timm_model.num_classes != 0: + config.num_labels = timm_model.num_classes + # infer ImageNet subset from timm model + imagenet_subset = infer_imagenet_subset(timm_model) + dataset_info = ImageNetInfo(imagenet_subset) + config.id2label = { + i: dataset_info.index_to_label_name(i) + for i in range(dataset_info.num_classes()) + } + config.label2id = {v: k for k, v in config.id2label.items()} + else: + print( + f"{ijepa_name} is going to be converted as a feature extractor only." + ) + base_model = True + + # load state_dict of original model + state_dict = timm_model.state_dict() + + # remove and rename some keys in the state dict + if base_model: + remove_classification_head_(state_dict) + rename_keys = create_rename_keys(config, base_model) + for src, dest in rename_keys: + rename_key(state_dict, src, dest) + read_in_q_k_v(state_dict, config, base_model) + + # load HuggingFace model + if base_model: + model = IJepaModel(config, add_pooling_layer=False).eval() + else: + model = IJepaForImageClassification(config).eval() + model.load_state_dict(state_dict) + + # Check outputs on an image, prepared by IJepaImageProcessor/DeiTImageProcessor + if "deit" in ijepa_name: + image_processor = DeiTImageProcessor(size=config.image_size) + else: + image_processor = IJepaImageProcessor(size=config.image_size) + encoding = image_processor(images=prepare_img(), return_tensors="pt") + pixel_values = encoding["pixel_values"] + outputs = model(pixel_values) + + if base_model: + timm_pooled_output = timm_model.forward_features(pixel_values) + assert timm_pooled_output.shape == outputs.last_hidden_state.shape + assert torch.allclose( + timm_pooled_output, outputs.last_hidden_state, atol=1e-1 + ) + else: + timm_logits = timm_model(pixel_values) + assert timm_logits.shape == outputs.logits.shape + assert torch.allclose(timm_logits, outputs.logits, atol=1e-3) + + Path(pytorch_dump_folder_path).mkdir(exist_ok=True) + print(f"Saving model {ijepa_name} to {pytorch_dump_folder_path}") + model.save_pretrained(pytorch_dump_folder_path) + print(f"Saving image processor to {pytorch_dump_folder_path}") + image_processor.save_pretrained(pytorch_dump_folder_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument( + "--ijepa_name", + default="ijepa_base_patch16_224", + type=str, + help="Name of the IJEPA timm model you'd like to convert.", + ) + parser.add_argument( + "--pytorch_dump_folder_path", + default=None, + type=str, + help="Path to the output PyTorch model directory.", + ) + + args = parser.parse_args() + convert_ijepa_checkpoint(args.ijepa_name, args.pytorch_dump_folder_path) diff --git a/src/transformers/models/ijepa/modeling_ijepa.py b/src/transformers/models/ijepa/modeling_ijepa.py new file mode 100644 index 000000000000..6cbe975b71c5 --- /dev/null +++ b/src/transformers/models/ijepa/modeling_ijepa.py @@ -0,0 +1,932 @@ +# coding=utf-8 +# Copyright 2024 Google AI, Ross Wightman, The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch I-JEPA model.""" + +import collections.abc +import math +from typing import Dict, List, Optional, Set, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from ...activations import ACT2FN +from ...modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithPooling, + ImageClassifierOutput, +) +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import ( + find_pruneable_heads_and_indices, + prune_linear_layer, +) +from ...utils import ( + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, +) +from .configuration_ijepa import IJepaConfig + +logger = logging.get_logger(__name__) + +# General docstring +_CONFIG_FOR_DOC = "IJepaConfig" + +# Base docstring +_CHECKPOINT_FOR_DOC = "facebook/I-JEPA" +_EXPECTED_OUTPUT_SHAPE = [1, 197, 768] + +# Image classification docstring +_IMAGE_CLASS_CHECKPOINT = "google/ijepa-base-patch16-224" +_IMAGE_CLASS_EXPECTED_OUTPUT = "Egyptian cat" + + +# Copied from transformers.models.vit.modeling_vit.ViTEmbeddings with ViT->IJEPA +class IJepaEmbeddings(nn.Module): + """ + Construct the CLS token, position and patch embeddings. Optionally, also the mask token. + """ + + def __init__( + self, config: IJepaConfig, use_mask_token: bool = False + ) -> None: + super().__init__() + + self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size)) + self.mask_token = ( + nn.Parameter(torch.zeros(1, 1, config.hidden_size)) + if use_mask_token + else None + ) + self.patch_embeddings = IJepaPatchEmbeddings(config) + num_patches = self.patch_embeddings.num_patches + self.position_embeddings = nn.Parameter( + torch.randn(1, num_patches + 1, config.hidden_size) + ) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.config = config + + def interpolate_pos_encoding( + self, embeddings: torch.Tensor, height: int, width: int + ) -> torch.Tensor: + """ + This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher + resolution images. + + Source: + https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174 + """ + + num_patches = embeddings.shape[1] - 1 + num_positions = self.position_embeddings.shape[1] - 1 + if num_patches == num_positions and height == width: + return self.position_embeddings + class_pos_embed = self.position_embeddings[:, 0] + patch_pos_embed = self.position_embeddings[:, 1:] + dim = embeddings.shape[-1] + h0 = height // self.config.patch_size + w0 = width // self.config.patch_size + # we add a small number to avoid floating point error in the interpolation + # see discussion at https://github.com/facebookresearch/dino/issues/8 + h0, w0 = h0 + 0.1, w0 + 0.1 + patch_pos_embed = patch_pos_embed.reshape( + 1, + int(math.sqrt(num_positions)), + int(math.sqrt(num_positions)), + dim, + ) + patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed, + scale_factor=( + h0 / math.sqrt(num_positions), + w0 / math.sqrt(num_positions), + ), + mode="bicubic", + align_corners=False, + ) + assert ( + int(h0) == patch_pos_embed.shape[-2] + and int(w0) == patch_pos_embed.shape[-1] + ) + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return torch.cat( + (class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1 + ) + + def forward( + self, + pixel_values: torch.Tensor, + bool_masked_pos: Optional[torch.BoolTensor] = None, + interpolate_pos_encoding: bool = False, + ) -> torch.Tensor: + batch_size, num_channels, height, width = pixel_values.shape + embeddings = self.patch_embeddings( + pixel_values, interpolate_pos_encoding=interpolate_pos_encoding + ) + + if bool_masked_pos is not None: + seq_length = embeddings.shape[1] + mask_tokens = self.mask_token.expand(batch_size, seq_length, -1) + # replace the masked visual tokens by mask_tokens + mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens) + embeddings = embeddings * (1.0 - mask) + mask_tokens * mask + + # add the [CLS] token to the embedded patch tokens + cls_tokens = self.cls_token.expand(batch_size, -1, -1) + embeddings = torch.cat((cls_tokens, embeddings), dim=1) + + # add positional encoding to each token + if interpolate_pos_encoding: + embeddings = embeddings + self.interpolate_pos_encoding( + embeddings, height, width + ) + else: + embeddings = embeddings + self.position_embeddings + + embeddings = self.dropout(embeddings) + + return embeddings + + +# Copied from transformers.models.vit.modeling_vit.ViTPatchEmbeddings with ViT->IJEPA +class IJepaPatchEmbeddings(nn.Module): + """ + This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial + `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a + Transformer. + """ + + def __init__(self, config): + super().__init__() + image_size, patch_size = config.image_size, config.patch_size + num_channels, hidden_size = config.num_channels, config.hidden_size + + image_size = ( + image_size + if isinstance(image_size, collections.abc.Iterable) + else (image_size, image_size) + ) + patch_size = ( + patch_size + if isinstance(patch_size, collections.abc.Iterable) + else (patch_size, patch_size) + ) + num_patches = (image_size[1] // patch_size[1]) * ( + image_size[0] // patch_size[0] + ) + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.num_patches = num_patches + + self.projection = nn.Conv2d( + num_channels, + hidden_size, + kernel_size=patch_size, + stride=patch_size, + ) + + def forward( + self, + pixel_values: torch.Tensor, + interpolate_pos_encoding: bool = False, + ) -> torch.Tensor: + batch_size, num_channels, height, width = pixel_values.shape + if num_channels != self.num_channels: + raise ValueError( + "Make sure that the channel dimension of the pixel values match with the one set in the configuration." + f" Expected {self.num_channels} but got {num_channels}." + ) + if not interpolate_pos_encoding: + if height != self.image_size[0] or width != self.image_size[1]: + raise ValueError( + f"Input image size ({height}*{width}) doesn't match model" + f" ({self.image_size[0]}*{self.image_size[1]})." + ) + embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2) + return embeddings + + +# Copied from transformers.models.vit.modeling_vit.ViTSelfAttention with ViT->IJEPA +class IJepaSelfAttention(nn.Module): + def __init__(self, config: IJepaConfig) -> None: + super().__init__() + if ( + config.hidden_size % config.num_attention_heads != 0 + and not hasattr(config, "embedding_size") + ): + raise ValueError( + f"The hidden size {config.hidden_size,} is not a multiple of the number of attention " + f"heads {config.num_attention_heads}." + ) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int( + config.hidden_size / config.num_attention_heads + ) + self.all_head_size = ( + self.num_attention_heads * self.attention_head_size + ) + + self.query = nn.Linear( + config.hidden_size, self.all_head_size, bias=config.qkv_bias + ) + self.key = nn.Linear( + config.hidden_size, self.all_head_size, bias=config.qkv_bias + ) + self.value = nn.Linear( + config.hidden_size, self.all_head_size, bias=config.qkv_bias + ) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + + def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: + new_x_shape = x.size()[:-1] + ( + self.num_attention_heads, + self.attention_head_size, + ) + x = x.view(new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, + hidden_states, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: + mixed_query_layer = self.query(hidden_states) + + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + query_layer = self.transpose_for_scores(mixed_query_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul( + query_layer, key_layer.transpose(-1, -2) + ) + + attention_scores = attention_scores / math.sqrt( + self.attention_head_size + ) + + # Normalize the attention scores to probabilities. + attention_probs = nn.functional.softmax(attention_scores, dim=-1) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = attention_probs * head_mask + + context_layer = torch.matmul(attention_probs, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + ( + self.all_head_size, + ) + context_layer = context_layer.view(new_context_layer_shape) + + outputs = ( + (context_layer, attention_probs) + if output_attentions + else (context_layer,) + ) + + return outputs + + +# Copied from transformers.models.vit.modeling_vit.ViTSdpaSelfAttention with ViT->IJEPA +class IJepaSdpaSelfAttention(IJepaSelfAttention): + def __init__(self, config: IJepaConfig) -> None: + super().__init__(config) + self.attention_probs_dropout_prob = config.attention_probs_dropout_prob + + def forward( + self, + hidden_states, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: + mixed_query_layer = self.query(hidden_states) + + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + query_layer = self.transpose_for_scores(mixed_query_layer) + + context_layer = torch.nn.functional.scaled_dot_product_attention( + query_layer, + key_layer, + value_layer, + head_mask, + self.attention_probs_dropout_prob if self.training else 0.0, + is_causal=False, + scale=None, + ) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + ( + self.all_head_size, + ) + context_layer = context_layer.view(new_context_layer_shape) + + return context_layer, None + + +# Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->IJEPA +class IJepaSelfOutput(nn.Module): + """ + The residual connection is defined in IJepaLayer instead of here (as is the case with other models), due to the + layernorm applied before each block. + """ + + def __init__(self, config: IJepaConfig) -> None: + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward( + self, hidden_states: torch.Tensor, input_tensor: torch.Tensor + ) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + + return hidden_states + + +# Copied from transformers.models.vit.modeling_vit.ViTAttention with ViT->IJEPA +class IJepaAttention(nn.Module): + def __init__(self, config: IJepaConfig) -> None: + super().__init__() + self.attention = IJepaSelfAttention(config) + self.output = IJepaSelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads: Set[int]) -> None: + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, + self.attention.num_attention_heads, + self.attention.attention_head_size, + self.pruned_heads, + ) + + # Prune linear layers + self.attention.query = prune_linear_layer(self.attention.query, index) + self.attention.key = prune_linear_layer(self.attention.key, index) + self.attention.value = prune_linear_layer(self.attention.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.attention.num_attention_heads = ( + self.attention.num_attention_heads - len(heads) + ) + self.attention.all_head_size = ( + self.attention.attention_head_size + * self.attention.num_attention_heads + ) + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states: torch.Tensor, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: + self_outputs = self.attention( + hidden_states, head_mask, output_attentions + ) + + attention_output = self.output(self_outputs[0], hidden_states) + + outputs = (attention_output,) + self_outputs[ + 1: + ] # add attentions if we output them + return outputs + + +# Copied from transformers.models.vit.modeling_vit.ViTSdpaAttention with ViT->IJEPA +class IJepaSdpaAttention(IJepaAttention): + def __init__(self, config: IJepaConfig) -> None: + super().__init__(config) + self.attention = IJepaSdpaSelfAttention(config) + + +# Copied from transformers.models.vit.modeling_vit.ViTIntermediate with ViT->IJEPA +class IJepaIntermediate(nn.Module): + def __init__(self, config: IJepaConfig) -> None: + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + + return hidden_states + + +# Copied from transformers.models.vit.modeling_vit.ViTOutput with ViT->IJEPA +class IJepaOutput(nn.Module): + def __init__(self, config: IJepaConfig) -> None: + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward( + self, hidden_states: torch.Tensor, input_tensor: torch.Tensor + ) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + + hidden_states = hidden_states + input_tensor + + return hidden_states + + +IJEPA_ATTENTION_CLASSES = { + "eager": IJepaAttention, + "sdpa": IJepaSdpaAttention, +} + + +# Copied from transformers.models.vit.modeling_vit.ViTLayer with VIT->IJEPA,ViT->IJEPA +class IJepaLayer(nn.Module): + """This corresponds to the Block class in the timm implementation.""" + + def __init__(self, config: IJepaConfig) -> None: + super().__init__() + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = IJEPA_ATTENTION_CLASSES[config._attn_implementation]( + config + ) + self.intermediate = IJepaIntermediate(config) + self.output = IJepaOutput(config) + self.layernorm_before = nn.LayerNorm( + config.hidden_size, eps=config.layer_norm_eps + ) + self.layernorm_after = nn.LayerNorm( + config.hidden_size, eps=config.layer_norm_eps + ) + + def forward( + self, + hidden_states: torch.Tensor, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: + self_attention_outputs = self.attention( + self.layernorm_before( + hidden_states + ), # in IJEPA, layernorm is applied before self-attention + head_mask, + output_attentions=output_attentions, + ) + attention_output = self_attention_outputs[0] + outputs = self_attention_outputs[ + 1: + ] # add self attentions if we output attention weights + + # first residual connection + hidden_states = attention_output + hidden_states + + # in IJEPA, layernorm is also applied after self-attention + layer_output = self.layernorm_after(hidden_states) + layer_output = self.intermediate(layer_output) + + # second residual connection is done here + layer_output = self.output(layer_output, hidden_states) + + outputs = (layer_output,) + outputs + + return outputs + + +# Copied from transformers.models.vit.modeling_vit.ViTEncoder with ViT->IJEPA +class IJepaEncoder(nn.Module): + def __init__(self, config: IJepaConfig) -> None: + super().__init__() + self.config = config + self.layer = nn.ModuleList( + [IJepaLayer(config) for _ in range(config.num_hidden_layers)] + ) + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ) -> Union[tuple, BaseModelOutput]: + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + layer_module.__call__, + hidden_states, + layer_head_mask, + output_attentions, + ) + else: + layer_outputs = layer_module( + hidden_states, layer_head_mask, output_attentions + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + all_hidden_states, + all_self_attentions, + ] + if v is not None + ) + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + +# Copied from transformers.models.vit.modeling_vit.ViTPreTrainedModel with ViT->IJEPA,vit->ijepa +class IJepaPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = IJepaConfig + base_model_prefix = "ijepa" + main_input_name = "pixel_values" + supports_gradient_checkpointing = True + _no_split_modules = ["IJepaEmbeddings", "IJepaLayer"] + _supports_sdpa = True + + def _init_weights( + self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm] + ) -> None: + """Initialize the weights""" + if isinstance(module, (nn.Linear, nn.Conv2d)): + # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid + # `trunc_normal_cpu` not implemented in `half` issues + module.weight.data = nn.init.trunc_normal_( + module.weight.data.to(torch.float32), + mean=0.0, + std=self.config.initializer_range, + ).to(module.weight.dtype) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, IJepaEmbeddings): + module.position_embeddings.data = nn.init.trunc_normal_( + module.position_embeddings.data.to(torch.float32), + mean=0.0, + std=self.config.initializer_range, + ).to(module.position_embeddings.dtype) + + module.cls_token.data = nn.init.trunc_normal_( + module.cls_token.data.to(torch.float32), + mean=0.0, + std=self.config.initializer_range, + ).to(module.cls_token.dtype) + + +IJEPA_START_DOCSTRING = r""" + This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it + as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and + behavior. + + Parameters: + config ([`IJepaConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +IJEPA_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`IJepaImageProcessor.__call__`] + for details. + + head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + interpolate_pos_encoding (`bool`, *optional*): + Whether to interpolate the pre-trained position encodings. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare I-JEPA Model transformer outputting raw hidden-states without any specific head on top.", + IJEPA_START_DOCSTRING, +) +# Copied from transformers.models.vit.modeling_vit.ViTModel with VIT->IJEPA,ViT->IJEPA +class IJepaModel(IJepaPreTrainedModel): + def __init__( + self, + config: IJepaConfig, + add_pooling_layer: bool = True, + use_mask_token: bool = False, + ): + super().__init__(config) + self.config = config + + self.embeddings = IJepaEmbeddings( + config, use_mask_token=use_mask_token + ) + self.encoder = IJepaEncoder(config) + + self.layernorm = nn.LayerNorm( + config.hidden_size, eps=config.layer_norm_eps + ) + self.pooler = IJepaPooler(config) if add_pooling_layer else None + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> IJepaPatchEmbeddings: + return self.embeddings.patch_embeddings + + def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None: + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + @add_start_docstrings_to_model_forward(IJEPA_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=BaseModelOutputWithPooling, + config_class=_CONFIG_FOR_DOC, + modality="vision", + expected_output=_EXPECTED_OUTPUT_SHAPE, + ) + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + bool_masked_pos: Optional[torch.BoolTensor] = None, + head_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + r""" + bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*): + Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). + """ + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict + if return_dict is not None + else self.config.use_return_dict + ) + + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask( + head_mask, self.config.num_hidden_layers + ) + + # TODO: maybe have a cleaner way to cast the input (from `ImageProcessor` side?) + expected_dtype = ( + self.embeddings.patch_embeddings.projection.weight.dtype + ) + if pixel_values.dtype != expected_dtype: + pixel_values = pixel_values.to(expected_dtype) + + embedding_output = self.embeddings( + pixel_values, + bool_masked_pos=bool_masked_pos, + interpolate_pos_encoding=interpolate_pos_encoding, + ) + + encoder_outputs = self.encoder( + embedding_output, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = encoder_outputs[0] + sequence_output = self.layernorm(sequence_output) + pooled_output = ( + self.pooler(sequence_output) if self.pooler is not None else None + ) + + if not return_dict: + head_outputs = ( + (sequence_output, pooled_output) + if pooled_output is not None + else (sequence_output,) + ) + return head_outputs + encoder_outputs[1:] + + return BaseModelOutputWithPooling( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +# Copied from transformers.models.vit.modeling_vit.ViTPooler with ViT->IJEPA +class IJepaPooler(nn.Module): + def __init__(self, config: IJepaConfig): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states): + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +@add_start_docstrings( + """ + I-JEPA Model transformer with an image classification head on top (a linear layer on top of the final hidden state of + the [CLS] token) e.g. for ImageNet. + + + + Note that it's possible to fine-tune I-JEPA on higher resolution images than the ones it has been trained on, by + setting `interpolate_pos_encoding` to `True` in the forward of the model. This will interpolate the pre-trained + position embeddings to the higher resolution. + + + """, + IJEPA_START_DOCSTRING, +) +# Copied from transformers.models.vit.modeling_vit.ViTForImageClassification with VIT->IJEPA,ViT->IJEPA,vit->ijepa +class IJepaForImageClassification(IJepaPreTrainedModel): + def __init__(self, config: IJepaConfig) -> None: + super().__init__(config) + + self.num_labels = config.num_labels + self.ijepa = IJepaModel(config, add_pooling_layer=False) + + # Classifier head + self.classifier = ( + nn.Linear(config.hidden_size, config.num_labels) + if config.num_labels > 0 + else nn.Identity() + ) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(IJEPA_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_IMAGE_CLASS_CHECKPOINT, + output_type=ImageClassifierOutput, + config_class=_CONFIG_FOR_DOC, + expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT, + ) + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, ImageClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the image classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = ( + return_dict + if return_dict is not None + else self.config.use_return_dict + ) + + outputs = self.ijepa( + pixel_values, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + logits = self.classifier(sequence_output[:, 0, :]) + + loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and ( + labels.dtype == torch.long or labels.dtype == torch.int + ): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct( + logits.view(-1, self.num_labels), labels.view(-1) + ) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return ImageClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/src/transformers/utils/dummy_flax_objects.py b/src/transformers/utils/dummy_flax_objects.py index 0f2390fb694b..f76da63e636f 100644 --- a/src/transformers/utils/dummy_flax_objects.py +++ b/src/transformers/utils/dummy_flax_objects.py @@ -1297,6 +1297,27 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["flax"]) +class FlaxIJepaForImageClassification(metaclass=DummyObject): + _backends = ["flax"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + +class FlaxIJepaModel(metaclass=DummyObject): + _backends = ["flax"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + +class FlaxIJepaPreTrainedModel(metaclass=DummyObject): + _backends = ["flax"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + class FlaxWav2Vec2ForCTC(metaclass=DummyObject): _backends = ["flax"] diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 4732ecea8611..e91fb2f32899 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -9127,6 +9127,27 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) +class IJepaForImageClassification(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class IJepaModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class IJepaPreTrainedModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class ViTMAEForPreTraining(metaclass=DummyObject): _backends = ["torch"] @@ -9921,7 +9942,9 @@ def get_cosine_schedule_with_warmup(*args, **kwargs): def get_cosine_with_hard_restarts_schedule_with_warmup(*args, **kwargs): - requires_backends(get_cosine_with_hard_restarts_schedule_with_warmup, ["torch"]) + requires_backends( + get_cosine_with_hard_restarts_schedule_with_warmup, ["torch"] + ) def get_inverse_sqrt_schedule(*args, **kwargs): diff --git a/src/transformers/utils/dummy_tf_objects.py b/src/transformers/utils/dummy_tf_objects.py index 942a7afced4b..2454dbf49a29 100644 --- a/src/transformers/utils/dummy_tf_objects.py +++ b/src/transformers/utils/dummy_tf_objects.py @@ -2627,6 +2627,27 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["tf"]) +class TFIJepaForImageClassification(metaclass=DummyObject): + _backends = ["tf"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["tf"]) + + +class TFIJepaModel(metaclass=DummyObject): + _backends = ["tf"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["tf"]) + + +class TFIJepaPreTrainedModel(metaclass=DummyObject): + _backends = ["tf"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["tf"]) + + class TFViTMAEForPreTraining(metaclass=DummyObject): _backends = ["tf"] diff --git a/src/transformers/utils/dummy_torchvision_objects.py b/src/transformers/utils/dummy_torchvision_objects.py index 1d532aeea2a4..b89a93fbcf0a 100644 --- a/src/transformers/utils/dummy_torchvision_objects.py +++ b/src/transformers/utils/dummy_torchvision_objects.py @@ -14,3 +14,10 @@ class ViTImageProcessorFast(metaclass=DummyObject): def __init__(self, *args, **kwargs): requires_backends(self, ["torchvision"]) + + +class IJepaImageProcessorFast(metaclass=DummyObject): + _backends = ["torchvision"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torchvision"]) diff --git a/src/transformers/utils/dummy_vision_objects.py b/src/transformers/utils/dummy_vision_objects.py index 19f8dc1b1d9c..39313bd60c62 100644 --- a/src/transformers/utils/dummy_vision_objects.py +++ b/src/transformers/utils/dummy_vision_objects.py @@ -632,6 +632,13 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["vision"]) +class IJepaImageProcessor(metaclass=DummyObject): + _backends = ["vision"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["vision"]) + + class VitMatteImageProcessor(metaclass=DummyObject): _backends = ["vision"] diff --git a/tests/models/ijepa/__init__.py b/tests/models/ijepa/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/models/ijepa/test_modeling_ijepa.py b/tests/models/ijepa/test_modeling_ijepa.py new file mode 100644 index 000000000000..ed5f448b17b1 --- /dev/null +++ b/tests/models/ijepa/test_modeling_ijepa.py @@ -0,0 +1,373 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Testing suite for the PyTorch IJEPA model.""" + +import unittest + +from transformers import IJepaConfig +from transformers.testing_utils import ( + require_accelerate, + require_torch, + require_torch_accelerator, + require_torch_fp16, + require_vision, + slow, + torch_device, +) +from transformers.utils import ( + cached_property, + is_torch_available, + is_vision_available, +) + +from ...test_configuration_common import ConfigTester +from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor +from ...test_pipeline_mixin import PipelineTesterMixin + +if is_torch_available(): + import torch + from torch import nn + + from transformers import IJepaForImageClassification, IJepaModel + + +if is_vision_available(): + from PIL import Image + + from transformers import IJepaImageProcessor + + +class IJepaModelTester: + def __init__( + self, + parent, + batch_size=13, + image_size=30, + patch_size=2, + num_channels=3, + is_training=True, + use_labels=True, + hidden_size=32, + num_hidden_layers=2, + num_attention_heads=4, + intermediate_size=37, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + type_sequence_label_size=10, + initializer_range=0.02, + scope=None, + encoder_stride=2, + mask_ratio=0.5, + attn_implementation="eager", + ): + self.parent = parent + self.batch_size = batch_size + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.is_training = is_training + self.use_labels = use_labels + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.type_sequence_label_size = type_sequence_label_size + self.initializer_range = initializer_range + self.scope = scope + self.encoder_stride = encoder_stride + self.attn_implementation = attn_implementation + + # in IJEPA, the seq length equals the number of patches + 1 (we add 1 for the [CLS] token) + num_patches = (image_size // patch_size) ** 2 + self.seq_length = num_patches + 1 + self.mask_ratio = mask_ratio + self.num_masks = int(mask_ratio * self.seq_length) + self.mask_length = num_patches + + def prepare_config_and_inputs(self): + pixel_values = floats_tensor( + [ + self.batch_size, + self.num_channels, + self.image_size, + self.image_size, + ] + ) + + labels = None + if self.use_labels: + labels = ids_tensor( + [self.batch_size], self.type_sequence_label_size + ) + + config = self.get_config() + + return config, pixel_values, labels + + def get_config(self): + return IJepaConfig( + image_size=self.image_size, + patch_size=self.patch_size, + num_channels=self.num_channels, + hidden_size=self.hidden_size, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads, + intermediate_size=self.intermediate_size, + hidden_act=self.hidden_act, + hidden_dropout_prob=self.hidden_dropout_prob, + attention_probs_dropout_prob=self.attention_probs_dropout_prob, + is_decoder=False, + initializer_range=self.initializer_range, + encoder_stride=self.encoder_stride, + attn_implementation=self.attn_implementation, + ) + + def create_and_check_model(self, config, pixel_values, labels): + model = IJepaModel(config=config) + model.to(torch_device) + model.eval() + result = model(pixel_values) + self.parent.assertEqual( + result.last_hidden_state.shape, + (self.batch_size, self.seq_length, self.hidden_size), + ) + + def create_and_check_for_image_classification( + self, config, pixel_values, labels + ): + config.num_labels = self.type_sequence_label_size + model = IJepaForImageClassification(config) + model.to(torch_device) + model.eval() + result = model(pixel_values, labels=labels) + self.parent.assertEqual( + result.logits.shape, + (self.batch_size, self.type_sequence_label_size), + ) + + # test greyscale images + config.num_channels = 1 + model = IJepaForImageClassification(config) + model.to(torch_device) + model.eval() + + pixel_values = floats_tensor( + [self.batch_size, 1, self.image_size, self.image_size] + ) + result = model(pixel_values) + self.parent.assertEqual( + result.logits.shape, + (self.batch_size, self.type_sequence_label_size), + ) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + ( + config, + pixel_values, + labels, + ) = config_and_inputs + inputs_dict = {"pixel_values": pixel_values} + return config, inputs_dict + + +@require_torch +class IJepaModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): + """ + Here we also overwrite some of the tests of test_modeling_common.py, as IJEPA does not use input_ids, inputs_embeds, + attention_mask and seq_length. + """ + + all_model_classes = ( + ( + IJepaModel, + IJepaForImageClassification, + ) + if is_torch_available() + else () + ) + fx_compatible = False + + test_pruning = False + test_resize_embeddings = False + test_head_masking = False + + def setUp(self): + self.model_tester = IJepaModelTester(self) + self.config_tester = ConfigTester( + self, + config_class=IJepaConfig, + has_text_modality=False, + hidden_size=37, + ) + + @unittest.skip( + "Since `torch==2.3+cu121`, although this test passes, many subsequent tests have `CUDA error: misaligned address`." + "If `nvidia-xxx-cu118` are also installed, no failure (even with `torch==2.3+cu121`)." + ) + def test_multi_gpu_data_parallel_forward(self): + super().test_multi_gpu_data_parallel_forward() + + def test_config(self): + self.config_tester.run_common_tests() + + @unittest.skip(reason="IJEPA does not use inputs_embeds") + def test_inputs_embeds(self): + pass + + def test_model_get_set_embeddings(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + self.assertIsInstance(model.get_input_embeddings(), (nn.Module)) + x = model.get_output_embeddings() + self.assertTrue(x is None or isinstance(x, nn.Linear)) + + def test_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(*config_and_inputs) + + def test_for_masked_image_modeling(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_for_masked_image_modeling( + *config_and_inputs + ) + + def test_for_image_classification(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_for_image_classification( + *config_and_inputs + ) + + @slow + def test_model_from_pretrained(self): + model_name = "google/ijepa-base-patch16-224" + model = IJepaModel.from_pretrained(model_name) + self.assertIsNotNone(model) + + +# We will verify our results on an image of cute cats +def prepare_img(): + image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png") + return image + + +@require_torch +@require_vision +class IJepaModelIntegrationTest(unittest.TestCase): + @cached_property + def default_image_processor(self): + return ( + IJepaImageProcessor.from_pretrained( + "google/ijepa-base-patch16-224" + ) + if is_vision_available() + else None + ) + + @slow + def test_inference_image_classification_head(self): + model = IJepaForImageClassification.from_pretrained( + "google/ijepa-base-patch16-224" + ).to(torch_device) + + image_processor = self.default_image_processor + image = prepare_img() + inputs = image_processor(images=image, return_tensors="pt").to( + torch_device + ) + + # forward pass + with torch.no_grad(): + outputs = model(**inputs) + + # verify the logits + expected_shape = torch.Size((1, 1000)) + self.assertEqual(outputs.logits.shape, expected_shape) + + expected_slice = torch.tensor([-0.2744, 0.8215, -0.0836]).to( + torch_device + ) + + self.assertTrue( + torch.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4) + ) + + @slow + def test_inference_interpolate_pos_encoding(self): + # IJEPA models have an `interpolate_pos_encoding` argument in their forward method, + # allowing to interpolate the pre-trained position embeddings in order to use + # the model on higher resolutions. The DINO model by Facebook AI leverages this + # to visualize self-attention on higher resolution images. + model = IJepaModel.from_pretrained("facebook/dino-ijepas8").to( + torch_device + ) + + image_processor = IJepaImageProcessor.from_pretrained( + "facebook/dino-ijepas8", size=480 + ) + image = prepare_img() + inputs = image_processor(images=image, return_tensors="pt") + pixel_values = inputs.pixel_values.to(torch_device) + + # forward pass + with torch.no_grad(): + outputs = model(pixel_values, interpolate_pos_encoding=True) + + # verify the logits + expected_shape = torch.Size((1, 3601, 384)) + self.assertEqual(outputs.last_hidden_state.shape, expected_shape) + + expected_slice = torch.tensor( + [ + [4.2340, 4.3906, -6.6692], + [4.5463, 1.8928, -6.7257], + [4.4429, 0.8496, -5.8585], + ] + ).to(torch_device) + + self.assertTrue( + torch.allclose( + outputs.last_hidden_state[0, :3, :3], expected_slice, atol=1e-4 + ) + ) + + @slow + @require_accelerate + @require_torch_accelerator + @require_torch_fp16 + def test_inference_fp16(self): + r""" + A small test to make sure that inference work in half precision without any problem. + """ + model = IJepaModel.from_pretrained( + "facebook/dino-ijepas8", + torch_dtype=torch.float16, + device_map="auto", + ) + image_processor = self.default_image_processor + + image = prepare_img() + inputs = image_processor(images=image, return_tensors="pt") + pixel_values = inputs.pixel_values.to(torch_device) + + # forward pass to make sure inference works in fp16 + with torch.no_grad(): + _ = model(pixel_values) From b9d7c03b0729c26fa7b255bd97cd6648aed1f0c6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Marcelo=20Tozato?= Date: Thu, 22 Aug 2024 15:37:30 +0200 Subject: [PATCH 02/44] add IJepaEmbeddings class --- docs/source/en/index.md | 2 +- src/transformers/__init__.py | 2 +- src/transformers/models/auto/modeling_auto.py | 1 + src/transformers/models/ijepa/__init__.py | 5 +- .../models/ijepa/configuration_ijepa.py | 1 + .../models/ijepa/convert_dino_to_pytorch.py | 50 ++-- .../ijepa/convert_ijepa_timm_to_pytorch.py | 80 ++---- .../models/ijepa/modeling_ijepa.py | 232 ++++-------------- src/transformers/utils/dummy_flax_objects.py | 42 ++-- src/transformers/utils/dummy_pt_objects.py | 46 ++-- src/transformers/utils/dummy_tf_objects.py | 42 ++-- .../utils/dummy_torchvision_objects.py | 4 +- .../utils/dummy_vision_objects.py | 14 +- tests/models/ijepa/test_modeling_ijepa.py | 59 ++--- 14 files changed, 192 insertions(+), 388 deletions(-) diff --git a/docs/source/en/index.md b/docs/source/en/index.md index 963f88410b4f..5225ff328a5d 100644 --- a/docs/source/en/index.md +++ b/docs/source/en/index.md @@ -165,7 +165,7 @@ Flax), PyTorch, and/or TensorFlow. | [Hiera](model_doc/hiera) | ✅ | ❌ | ❌ | | [Hubert](model_doc/hubert) | ✅ | ✅ | ❌ | | [I-BERT](model_doc/ibert) | ✅ | ❌ | ❌ | -| [I-JEPA](model_doc/ijepa) | ✅ | ❌ | ❌ | +| [I-JEPA](model_doc/ijepa) | ✅ | ✅ | ✅ | | [IDEFICS](model_doc/idefics) | ✅ | ✅ | ❌ | | [Idefics2](model_doc/idefics2) | ✅ | ❌ | ❌ | | [ImageGPT](model_doc/imagegpt) | ✅ | ❌ | ❌ | diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 39cefb38022b..47ff11b3bf04 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -1218,7 +1218,7 @@ ] else: _import_structure["image_processing_utils_fast"] = ["BaseImageProcessorFast"] - _import_structure["models.ijepa"].append("IJEPAImageProcessorFast") + _import_structure["models.ijepa"].append("IJepaImageProcessorFast") _import_structure["models.vit"].append("ViTImageProcessorFast") # PyTorch-backed objects diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 4ff1e87a1dc1..c9d299454d10 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -630,6 +630,7 @@ ("efficientnet", "EfficientNetForImageClassification"), ("focalnet", "FocalNetForImageClassification"), ("hiera", "HieraForImageClassification"), + ("ijepa", "IJepaForImageClassification"), ("imagegpt", "ImageGPTForImageClassification"), ( "levit", diff --git a/src/transformers/models/ijepa/__init__.py b/src/transformers/models/ijepa/__init__.py index 685ef0e255c5..52ee746d156b 100644 --- a/src/transformers/models/ijepa/__init__.py +++ b/src/transformers/models/ijepa/__init__.py @@ -20,6 +20,7 @@ is_torchvision_available, ) + _import_structure = {"configuration_ijepa": ["IJepaConfig", "IJepaOnnxConfig"]} try: @@ -28,9 +29,7 @@ except OptionalDependencyNotAvailable: pass else: - _import_structure["image_processing_ijepa_fast"] = [ - "IJepaImageProcessorFast" - ] + _import_structure["image_processing_ijepa_fast"] = ["IJepaImageProcessorFast"] try: if not is_torch_available(): diff --git a/src/transformers/models/ijepa/configuration_ijepa.py b/src/transformers/models/ijepa/configuration_ijepa.py index 09481db69348..1bdcac4e25c3 100644 --- a/src/transformers/models/ijepa/configuration_ijepa.py +++ b/src/transformers/models/ijepa/configuration_ijepa.py @@ -23,6 +23,7 @@ from ...onnx import OnnxConfig from ...utils import logging + logger = logging.get_logger(__name__) diff --git a/src/transformers/models/ijepa/convert_dino_to_pytorch.py b/src/transformers/models/ijepa/convert_dino_to_pytorch.py index 351a4d11e179..e9da790659ba 100644 --- a/src/transformers/models/ijepa/convert_dino_to_pytorch.py +++ b/src/transformers/models/ijepa/convert_dino_to_pytorch.py @@ -31,6 +31,7 @@ ) from transformers.utils import logging + logging.set_verbosity_info() logger = logging.get_logger(__name__) @@ -127,10 +128,7 @@ def create_rename_keys(config, base_model=False): ) # if just the base model, we should remove "ijepa" from all keys that start with "ijepa" - rename_keys = [ - (pair[0], pair[1][4:]) if pair[1].startswith("ijepa") else pair - for pair in rename_keys - ] + rename_keys = [(pair[0], pair[1][4:]) if pair[1].startswith("ijepa") else pair for pair in rename_keys] else: # layernorm + classification head rename_keys.extend( @@ -156,24 +154,20 @@ def read_in_q_k_v(state_dict, config, base_model=False): in_proj_weight = state_dict.pop(f"blocks.{i}.attn.qkv.weight") in_proj_bias = state_dict.pop(f"blocks.{i}.attn.qkv.bias") # next, add query, keys and values (in that order) to the state dict - state_dict[ - f"{prefix}encoder.layer.{i}.attention.attention.query.weight" - ] = in_proj_weight[: config.hidden_size, :] - state_dict[ - f"{prefix}encoder.layer.{i}.attention.attention.query.bias" - ] = in_proj_bias[: config.hidden_size] - state_dict[ - f"{prefix}encoder.layer.{i}.attention.attention.key.weight" - ] = in_proj_weight[config.hidden_size : config.hidden_size * 2, :] - state_dict[ - f"{prefix}encoder.layer.{i}.attention.attention.key.bias" - ] = in_proj_bias[config.hidden_size : config.hidden_size * 2] - state_dict[ - f"{prefix}encoder.layer.{i}.attention.attention.value.weight" - ] = in_proj_weight[-config.hidden_size :, :] - state_dict[ - f"{prefix}encoder.layer.{i}.attention.attention.value.bias" - ] = in_proj_bias[-config.hidden_size :] + state_dict[f"{prefix}encoder.layer.{i}.attention.attention.query.weight"] = in_proj_weight[ + : config.hidden_size, : + ] + state_dict[f"{prefix}encoder.layer.{i}.attention.attention.query.bias"] = in_proj_bias[: config.hidden_size] + state_dict[f"{prefix}encoder.layer.{i}.attention.attention.key.weight"] = in_proj_weight[ + config.hidden_size : config.hidden_size * 2, : + ] + state_dict[f"{prefix}encoder.layer.{i}.attention.attention.key.bias"] = in_proj_bias[ + config.hidden_size : config.hidden_size * 2 + ] + state_dict[f"{prefix}encoder.layer.{i}.attention.attention.value.weight"] = in_proj_weight[ + -config.hidden_size :, : + ] + state_dict[f"{prefix}encoder.layer.{i}.attention.attention.value.bias"] = in_proj_bias[-config.hidden_size :] def remove_classification_head_(state_dict): @@ -195,9 +189,7 @@ def prepare_img(): @torch.no_grad() -def convert_ijepa_checkpoint( - model_name, pytorch_dump_folder_path, base_model=True -): +def convert_ijepa_checkpoint(model_name, pytorch_dump_folder_path, base_model=True): """ Copy/paste/tweak model's weights to our IJEPA structure. """ @@ -212,9 +204,7 @@ def convert_ijepa_checkpoint( config.num_labels = 1000 repo_id = "huggingface/label-files" filename = "imagenet-1k-id2label.json" - id2label = json.load( - open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r") - ) + id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r")) id2label = {int(k): v for k, v in id2label.items()} config.id2label = id2label config.label2id = {v: k for k, v in id2label.items()} @@ -293,6 +283,4 @@ def convert_ijepa_checkpoint( parser.set_defaults(base_model=True) args = parser.parse_args() - convert_ijepa_checkpoint( - args.model_name, args.pytorch_dump_folder_path, args.base_model - ) + convert_ijepa_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.base_model) diff --git a/src/transformers/models/ijepa/convert_ijepa_timm_to_pytorch.py b/src/transformers/models/ijepa/convert_ijepa_timm_to_pytorch.py index 607e4d58462a..2ff83b964c1b 100644 --- a/src/transformers/models/ijepa/convert_ijepa_timm_to_pytorch.py +++ b/src/transformers/models/ijepa/convert_ijepa_timm_to_pytorch.py @@ -32,6 +32,7 @@ ) from transformers.utils import logging + logging.set_verbosity_info() logger = logging.get_logger(__name__) @@ -128,10 +129,7 @@ def create_rename_keys(config, base_model=False): ) # if just the base model, we should remove "ijepa" from all keys that start with "ijepa" - rename_keys = [ - (pair[0], pair[1][4:]) if pair[1].startswith("ijepa") else pair - for pair in rename_keys - ] + rename_keys = [(pair[0], pair[1][4:]) if pair[1].startswith("ijepa") else pair for pair in rename_keys] else: # layernorm + classification head rename_keys.extend( @@ -157,24 +155,20 @@ def read_in_q_k_v(state_dict, config, base_model=False): in_proj_weight = state_dict.pop(f"blocks.{i}.attn.qkv.weight") in_proj_bias = state_dict.pop(f"blocks.{i}.attn.qkv.bias") # next, add query, keys and values (in that order) to the state dict - state_dict[ - f"{prefix}encoder.layer.{i}.attention.attention.query.weight" - ] = in_proj_weight[: config.hidden_size, :] - state_dict[ - f"{prefix}encoder.layer.{i}.attention.attention.query.bias" - ] = in_proj_bias[: config.hidden_size] - state_dict[ - f"{prefix}encoder.layer.{i}.attention.attention.key.weight" - ] = in_proj_weight[config.hidden_size : config.hidden_size * 2, :] - state_dict[ - f"{prefix}encoder.layer.{i}.attention.attention.key.bias" - ] = in_proj_bias[config.hidden_size : config.hidden_size * 2] - state_dict[ - f"{prefix}encoder.layer.{i}.attention.attention.value.weight" - ] = in_proj_weight[-config.hidden_size :, :] - state_dict[ - f"{prefix}encoder.layer.{i}.attention.attention.value.bias" - ] = in_proj_bias[-config.hidden_size :] + state_dict[f"{prefix}encoder.layer.{i}.attention.attention.query.weight"] = in_proj_weight[ + : config.hidden_size, : + ] + state_dict[f"{prefix}encoder.layer.{i}.attention.attention.query.bias"] = in_proj_bias[: config.hidden_size] + state_dict[f"{prefix}encoder.layer.{i}.attention.attention.key.weight"] = in_proj_weight[ + config.hidden_size : config.hidden_size * 2, : + ] + state_dict[f"{prefix}encoder.layer.{i}.attention.attention.key.bias"] = in_proj_bias[ + config.hidden_size : config.hidden_size * 2 + ] + state_dict[f"{prefix}encoder.layer.{i}.attention.attention.value.weight"] = in_proj_weight[ + -config.hidden_size :, : + ] + state_dict[f"{prefix}encoder.layer.{i}.attention.attention.value.bias"] = in_proj_bias[-config.hidden_size :] def remove_classification_head_(state_dict): @@ -212,48 +206,33 @@ def convert_ijepa_checkpoint(ijepa_name, pytorch_dump_folder_path): # detect unsupported IJEPA models in transformers # fc_norm is present if not isinstance(getattr(timm_model, "fc_norm", None), torch.nn.Identity): - raise ValueError( - f"{ijepa_name} is not supported in transformers because of the presence of fc_norm." - ) + raise ValueError(f"{ijepa_name} is not supported in transformers because of the presence of fc_norm.") # use of global average pooling in combination (or without) class token if getattr(timm_model, "global_pool", None) == "avg": - raise ValueError( - f"{ijepa_name} is not supported in transformers because of use of global average pooling." - ) + raise ValueError(f"{ijepa_name} is not supported in transformers because of use of global average pooling.") # CLIP style ijepa with norm_pre layer present - if "clip" in ijepa_name and not isinstance( - getattr(timm_model, "norm_pre", None), torch.nn.Identity - ): + if "clip" in ijepa_name and not isinstance(getattr(timm_model, "norm_pre", None), torch.nn.Identity): raise ValueError( f"{ijepa_name} is not supported in transformers because it's a CLIP style IJEPA with norm_pre layer." ) # SigLIP style ijepa with attn_pool layer present - if ( - "siglip" in ijepa_name - and getattr(timm_model, "global_pool", None) == "map" - ): + if "siglip" in ijepa_name and getattr(timm_model, "global_pool", None) == "map": raise ValueError( f"{ijepa_name} is not supported in transformers because it's a SigLIP style IJEPA with attn_pool." ) # use of layer scale in IJEPA model blocks - if not isinstance( - getattr(timm_model.blocks[0], "ls1", None), torch.nn.Identity - ) or not isinstance( + if not isinstance(getattr(timm_model.blocks[0], "ls1", None), torch.nn.Identity) or not isinstance( getattr(timm_model.blocks[0], "ls2", None), torch.nn.Identity ): - raise ValueError( - f"{ijepa_name} is not supported in transformers because it uses a layer scale in its blocks." - ) + raise ValueError(f"{ijepa_name} is not supported in transformers because it uses a layer scale in its blocks.") # Hybrid ResNet-IJEPAs if not isinstance(timm_model.patch_embed, timm.layers.PatchEmbed): - raise ValueError( - f"{ijepa_name} is not supported in transformers because it is a hybrid ResNet-IJEPA." - ) + raise ValueError(f"{ijepa_name} is not supported in transformers because it is a hybrid ResNet-IJEPA.") # get patch size and image size from the patch embedding submodule config.patch_size = timm_model.patch_embed.patch_size[0] @@ -271,15 +250,10 @@ def convert_ijepa_checkpoint(ijepa_name, pytorch_dump_folder_path): # infer ImageNet subset from timm model imagenet_subset = infer_imagenet_subset(timm_model) dataset_info = ImageNetInfo(imagenet_subset) - config.id2label = { - i: dataset_info.index_to_label_name(i) - for i in range(dataset_info.num_classes()) - } + config.id2label = {i: dataset_info.index_to_label_name(i) for i in range(dataset_info.num_classes())} config.label2id = {v: k for k, v in config.id2label.items()} else: - print( - f"{ijepa_name} is going to be converted as a feature extractor only." - ) + print(f"{ijepa_name} is going to be converted as a feature extractor only.") base_model = True # load state_dict of original model @@ -312,9 +286,7 @@ def convert_ijepa_checkpoint(ijepa_name, pytorch_dump_folder_path): if base_model: timm_pooled_output = timm_model.forward_features(pixel_values) assert timm_pooled_output.shape == outputs.last_hidden_state.shape - assert torch.allclose( - timm_pooled_output, outputs.last_hidden_state, atol=1e-1 - ) + assert torch.allclose(timm_pooled_output, outputs.last_hidden_state, atol=1e-1) else: timm_logits = timm_model(pixel_values) assert timm_logits.shape == outputs.logits.shape diff --git a/src/transformers/models/ijepa/modeling_ijepa.py b/src/transformers/models/ijepa/modeling_ijepa.py index 6cbe975b71c5..ec9ec009d7e6 100644 --- a/src/transformers/models/ijepa/modeling_ijepa.py +++ b/src/transformers/models/ijepa/modeling_ijepa.py @@ -42,6 +42,7 @@ ) from .configuration_ijepa import IJepaConfig + logger = logging.get_logger(__name__) # General docstring @@ -56,34 +57,22 @@ _IMAGE_CLASS_EXPECTED_OUTPUT = "Egyptian cat" -# Copied from transformers.models.vit.modeling_vit.ViTEmbeddings with ViT->IJEPA class IJepaEmbeddings(nn.Module): """ - Construct the CLS token, position and patch embeddings. Optionally, also the mask token. + Construct the position and patch embeddings. Optionally, also the mask token. """ - def __init__( - self, config: IJepaConfig, use_mask_token: bool = False - ) -> None: + def __init__(self, config: IJepaConfig, use_mask_token: bool = False) -> None: super().__init__() - self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size)) - self.mask_token = ( - nn.Parameter(torch.zeros(1, 1, config.hidden_size)) - if use_mask_token - else None - ) + self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) if use_mask_token else None self.patch_embeddings = IJepaPatchEmbeddings(config) num_patches = self.patch_embeddings.num_patches - self.position_embeddings = nn.Parameter( - torch.randn(1, num_patches + 1, config.hidden_size) - ) + self.position_embeddings = nn.Parameter(torch.randn(1, num_patches + 1, config.hidden_size)) self.dropout = nn.Dropout(config.hidden_dropout_prob) self.config = config - def interpolate_pos_encoding( - self, embeddings: torch.Tensor, height: int, width: int - ) -> torch.Tensor: + def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: """ This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution images. @@ -120,14 +109,9 @@ def interpolate_pos_encoding( mode="bicubic", align_corners=False, ) - assert ( - int(h0) == patch_pos_embed.shape[-2] - and int(w0) == patch_pos_embed.shape[-1] - ) + assert int(h0) == patch_pos_embed.shape[-2] and int(w0) == patch_pos_embed.shape[-1] patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) - return torch.cat( - (class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1 - ) + return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) def forward( self, @@ -135,10 +119,8 @@ def forward( bool_masked_pos: Optional[torch.BoolTensor] = None, interpolate_pos_encoding: bool = False, ) -> torch.Tensor: - batch_size, num_channels, height, width = pixel_values.shape - embeddings = self.patch_embeddings( - pixel_values, interpolate_pos_encoding=interpolate_pos_encoding - ) + batch_size, _, height, width = pixel_values.shape + embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding) if bool_masked_pos is not None: seq_length = embeddings.shape[1] @@ -147,15 +129,9 @@ def forward( mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens) embeddings = embeddings * (1.0 - mask) + mask_tokens * mask - # add the [CLS] token to the embedded patch tokens - cls_tokens = self.cls_token.expand(batch_size, -1, -1) - embeddings = torch.cat((cls_tokens, embeddings), dim=1) - # add positional encoding to each token if interpolate_pos_encoding: - embeddings = embeddings + self.interpolate_pos_encoding( - embeddings, height, width - ) + embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width) else: embeddings = embeddings + self.position_embeddings @@ -177,19 +153,9 @@ def __init__(self, config): image_size, patch_size = config.image_size, config.patch_size num_channels, hidden_size = config.num_channels, config.hidden_size - image_size = ( - image_size - if isinstance(image_size, collections.abc.Iterable) - else (image_size, image_size) - ) - patch_size = ( - patch_size - if isinstance(patch_size, collections.abc.Iterable) - else (patch_size, patch_size) - ) - num_patches = (image_size[1] // patch_size[1]) * ( - image_size[0] // patch_size[0] - ) + image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size) + patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size) + num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) self.image_size = image_size self.patch_size = patch_size self.num_channels = num_channels @@ -227,32 +193,19 @@ def forward( class IJepaSelfAttention(nn.Module): def __init__(self, config: IJepaConfig) -> None: super().__init__() - if ( - config.hidden_size % config.num_attention_heads != 0 - and not hasattr(config, "embedding_size") - ): + if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): raise ValueError( f"The hidden size {config.hidden_size,} is not a multiple of the number of attention " f"heads {config.num_attention_heads}." ) self.num_attention_heads = config.num_attention_heads - self.attention_head_size = int( - config.hidden_size / config.num_attention_heads - ) - self.all_head_size = ( - self.num_attention_heads * self.attention_head_size - ) + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size - self.query = nn.Linear( - config.hidden_size, self.all_head_size, bias=config.qkv_bias - ) - self.key = nn.Linear( - config.hidden_size, self.all_head_size, bias=config.qkv_bias - ) - self.value = nn.Linear( - config.hidden_size, self.all_head_size, bias=config.qkv_bias - ) + self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) + self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) + self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) self.dropout = nn.Dropout(config.attention_probs_dropout_prob) @@ -277,13 +230,9 @@ def forward( query_layer = self.transpose_for_scores(mixed_query_layer) # Take the dot product between "query" and "key" to get the raw attention scores. - attention_scores = torch.matmul( - query_layer, key_layer.transpose(-1, -2) - ) + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) - attention_scores = attention_scores / math.sqrt( - self.attention_head_size - ) + attention_scores = attention_scores / math.sqrt(self.attention_head_size) # Normalize the attention scores to probabilities. attention_probs = nn.functional.softmax(attention_scores, dim=-1) @@ -299,16 +248,10 @@ def forward( context_layer = torch.matmul(attention_probs, value_layer) context_layer = context_layer.permute(0, 2, 1, 3).contiguous() - new_context_layer_shape = context_layer.size()[:-2] + ( - self.all_head_size, - ) + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) context_layer = context_layer.view(new_context_layer_shape) - outputs = ( - (context_layer, attention_probs) - if output_attentions - else (context_layer,) - ) + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) return outputs @@ -342,9 +285,7 @@ def forward( ) context_layer = context_layer.permute(0, 2, 1, 3).contiguous() - new_context_layer_shape = context_layer.size()[:-2] + ( - self.all_head_size, - ) + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) context_layer = context_layer.view(new_context_layer_shape) return context_layer, None @@ -362,9 +303,7 @@ def __init__(self, config: IJepaConfig) -> None: self.dense = nn.Linear(config.hidden_size, config.hidden_size) self.dropout = nn.Dropout(config.hidden_dropout_prob) - def forward( - self, hidden_states: torch.Tensor, input_tensor: torch.Tensor - ) -> torch.Tensor: + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: hidden_states = self.dense(hidden_states) hidden_states = self.dropout(hidden_states) @@ -396,13 +335,8 @@ def prune_heads(self, heads: Set[int]) -> None: self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) # Update hyper params and store pruned heads - self.attention.num_attention_heads = ( - self.attention.num_attention_heads - len(heads) - ) - self.attention.all_head_size = ( - self.attention.attention_head_size - * self.attention.num_attention_heads - ) + self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads) + self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads self.pruned_heads = self.pruned_heads.union(heads) def forward( @@ -411,15 +345,11 @@ def forward( head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: - self_outputs = self.attention( - hidden_states, head_mask, output_attentions - ) + self_outputs = self.attention(hidden_states, head_mask, output_attentions) attention_output = self.output(self_outputs[0], hidden_states) - outputs = (attention_output,) + self_outputs[ - 1: - ] # add attentions if we output them + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them return outputs @@ -454,9 +384,7 @@ def __init__(self, config: IJepaConfig) -> None: self.dense = nn.Linear(config.intermediate_size, config.hidden_size) self.dropout = nn.Dropout(config.hidden_dropout_prob) - def forward( - self, hidden_states: torch.Tensor, input_tensor: torch.Tensor - ) -> torch.Tensor: + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: hidden_states = self.dense(hidden_states) hidden_states = self.dropout(hidden_states) @@ -479,17 +407,11 @@ def __init__(self, config: IJepaConfig) -> None: super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward self.seq_len_dim = 1 - self.attention = IJEPA_ATTENTION_CLASSES[config._attn_implementation]( - config - ) + self.attention = IJEPA_ATTENTION_CLASSES[config._attn_implementation](config) self.intermediate = IJepaIntermediate(config) self.output = IJepaOutput(config) - self.layernorm_before = nn.LayerNorm( - config.hidden_size, eps=config.layer_norm_eps - ) - self.layernorm_after = nn.LayerNorm( - config.hidden_size, eps=config.layer_norm_eps - ) + self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) def forward( self, @@ -498,16 +420,12 @@ def forward( output_attentions: bool = False, ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: self_attention_outputs = self.attention( - self.layernorm_before( - hidden_states - ), # in IJEPA, layernorm is applied before self-attention + self.layernorm_before(hidden_states), # in IJEPA, layernorm is applied before self-attention head_mask, output_attentions=output_attentions, ) attention_output = self_attention_outputs[0] - outputs = self_attention_outputs[ - 1: - ] # add self attentions if we output attention weights + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights # first residual connection hidden_states = attention_output + hidden_states @@ -529,9 +447,7 @@ class IJepaEncoder(nn.Module): def __init__(self, config: IJepaConfig) -> None: super().__init__() self.config = config - self.layer = nn.ModuleList( - [IJepaLayer(config) for _ in range(config.num_hidden_layers)] - ) + self.layer = nn.ModuleList([IJepaLayer(config) for _ in range(config.num_hidden_layers)]) self.gradient_checkpointing = False def forward( @@ -559,9 +475,7 @@ def forward( output_attentions, ) else: - layer_outputs = layer_module( - hidden_states, layer_head_mask, output_attentions - ) + layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) hidden_states = layer_outputs[0] @@ -588,7 +502,7 @@ def forward( ) -# Copied from transformers.models.vit.modeling_vit.ViTPreTrainedModel with ViT->IJEPA,vit->ijepa +# Copied from transformers.models.vit.modeling_vit.ViTPreTrainedModel with ViT->IJepa,vit->ijepa class IJepaPreTrainedModel(PreTrainedModel): """ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained @@ -602,9 +516,7 @@ class IJepaPreTrainedModel(PreTrainedModel): _no_split_modules = ["IJepaEmbeddings", "IJepaLayer"] _supports_sdpa = True - def _init_weights( - self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm] - ) -> None: + def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None: """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Conv2d)): # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid @@ -684,14 +596,10 @@ def __init__( super().__init__(config) self.config = config - self.embeddings = IJepaEmbeddings( - config, use_mask_token=use_mask_token - ) + self.embeddings = IJepaEmbeddings(config, use_mask_token=use_mask_token) self.encoder = IJepaEncoder(config) - self.layernorm = nn.LayerNorm( - config.hidden_size, eps=config.layer_norm_eps - ) + self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.pooler = IJepaPooler(config) if add_pooling_layer else None # Initialize weights and apply final processing @@ -730,21 +638,11 @@ def forward( bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*): Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). """ - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - return_dict = ( - return_dict - if return_dict is not None - else self.config.use_return_dict + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict if pixel_values is None: raise ValueError("You have to specify pixel_values") @@ -754,14 +652,10 @@ def forward( # attention_probs has shape bsz x n_heads x N x N # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] - head_mask = self.get_head_mask( - head_mask, self.config.num_hidden_layers - ) + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) # TODO: maybe have a cleaner way to cast the input (from `ImageProcessor` side?) - expected_dtype = ( - self.embeddings.patch_embeddings.projection.weight.dtype - ) + expected_dtype = self.embeddings.patch_embeddings.projection.weight.dtype if pixel_values.dtype != expected_dtype: pixel_values = pixel_values.to(expected_dtype) @@ -780,16 +674,10 @@ def forward( ) sequence_output = encoder_outputs[0] sequence_output = self.layernorm(sequence_output) - pooled_output = ( - self.pooler(sequence_output) if self.pooler is not None else None - ) + pooled_output = self.pooler(sequence_output) if self.pooler is not None else None if not return_dict: - head_outputs = ( - (sequence_output, pooled_output) - if pooled_output is not None - else (sequence_output,) - ) + head_outputs = (sequence_output, pooled_output) if pooled_output is not None else (sequence_output,) return head_outputs + encoder_outputs[1:] return BaseModelOutputWithPooling( @@ -819,7 +707,7 @@ def forward(self, hidden_states): @add_start_docstrings( """ I-JEPA Model transformer with an image classification head on top (a linear layer on top of the final hidden state of - the [CLS] token) e.g. for ImageNet. + the average pooling of the output) e.g. for ImageNet. @@ -840,11 +728,7 @@ def __init__(self, config: IJepaConfig) -> None: self.ijepa = IJepaModel(config, add_pooling_layer=False) # Classifier head - self.classifier = ( - nn.Linear(config.hidden_size, config.num_labels) - if config.num_labels > 0 - else nn.Identity() - ) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity() # Initialize weights and apply final processing self.post_init() @@ -872,11 +756,7 @@ def forward( config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). """ - return_dict = ( - return_dict - if return_dict is not None - else self.config.use_return_dict - ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict outputs = self.ijepa( pixel_values, @@ -898,9 +778,7 @@ def forward( if self.config.problem_type is None: if self.num_labels == 1: self.config.problem_type = "regression" - elif self.num_labels > 1 and ( - labels.dtype == torch.long or labels.dtype == torch.int - ): + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): self.config.problem_type = "single_label_classification" else: self.config.problem_type = "multi_label_classification" @@ -913,9 +791,7 @@ def forward( loss = loss_fct(logits, labels) elif self.config.problem_type == "single_label_classification": loss_fct = CrossEntropyLoss() - loss = loss_fct( - logits.view(-1, self.num_labels), labels.view(-1) - ) + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) elif self.config.problem_type == "multi_label_classification": loss_fct = BCEWithLogitsLoss() loss = loss_fct(logits, labels) diff --git a/src/transformers/utils/dummy_flax_objects.py b/src/transformers/utils/dummy_flax_objects.py index f76da63e636f..cb8f0e8ff047 100644 --- a/src/transformers/utils/dummy_flax_objects.py +++ b/src/transformers/utils/dummy_flax_objects.py @@ -842,6 +842,27 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["flax"]) +class FlaxIJepaForImageClassification(metaclass=DummyObject): + _backends = ["flax"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + +class FlaxIJepaModel(metaclass=DummyObject): + _backends = ["flax"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + +class FlaxIJepaPreTrainedModel(metaclass=DummyObject): + _backends = ["flax"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + class FlaxLlamaForCausalLM(metaclass=DummyObject): _backends = ["flax"] @@ -1297,27 +1318,6 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["flax"]) -class FlaxIJepaForImageClassification(metaclass=DummyObject): - _backends = ["flax"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["flax"]) - - -class FlaxIJepaModel(metaclass=DummyObject): - _backends = ["flax"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["flax"]) - - -class FlaxIJepaPreTrainedModel(metaclass=DummyObject): - _backends = ["flax"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["flax"]) - - class FlaxWav2Vec2ForCTC(metaclass=DummyObject): _backends = ["flax"] diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index e91fb2f32899..acb4b03b2d01 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -4852,6 +4852,27 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) +class IJepaForImageClassification(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class IJepaModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class IJepaPreTrainedModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class ImageGPTForCausalImageModeling(metaclass=DummyObject): _backends = ["torch"] @@ -9127,27 +9148,6 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) -class IJepaForImageClassification(metaclass=DummyObject): - _backends = ["torch"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch"]) - - -class IJepaModel(metaclass=DummyObject): - _backends = ["torch"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch"]) - - -class IJepaPreTrainedModel(metaclass=DummyObject): - _backends = ["torch"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch"]) - - class ViTMAEForPreTraining(metaclass=DummyObject): _backends = ["torch"] @@ -9942,9 +9942,7 @@ def get_cosine_schedule_with_warmup(*args, **kwargs): def get_cosine_with_hard_restarts_schedule_with_warmup(*args, **kwargs): - requires_backends( - get_cosine_with_hard_restarts_schedule_with_warmup, ["torch"] - ) + requires_backends(get_cosine_with_hard_restarts_schedule_with_warmup, ["torch"]) def get_inverse_sqrt_schedule(*args, **kwargs): diff --git a/src/transformers/utils/dummy_tf_objects.py b/src/transformers/utils/dummy_tf_objects.py index 2454dbf49a29..7663a6d652d5 100644 --- a/src/transformers/utils/dummy_tf_objects.py +++ b/src/transformers/utils/dummy_tf_objects.py @@ -1563,6 +1563,27 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["tf"]) +class TFIJepaForImageClassification(metaclass=DummyObject): + _backends = ["tf"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["tf"]) + + +class TFIJepaModel(metaclass=DummyObject): + _backends = ["tf"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["tf"]) + + +class TFIJepaPreTrainedModel(metaclass=DummyObject): + _backends = ["tf"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["tf"]) + + class TFLayoutLMForMaskedLM(metaclass=DummyObject): _backends = ["tf"] @@ -2627,27 +2648,6 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["tf"]) -class TFIJepaForImageClassification(metaclass=DummyObject): - _backends = ["tf"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["tf"]) - - -class TFIJepaModel(metaclass=DummyObject): - _backends = ["tf"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["tf"]) - - -class TFIJepaPreTrainedModel(metaclass=DummyObject): - _backends = ["tf"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["tf"]) - - class TFViTMAEForPreTraining(metaclass=DummyObject): _backends = ["tf"] diff --git a/src/transformers/utils/dummy_torchvision_objects.py b/src/transformers/utils/dummy_torchvision_objects.py index b89a93fbcf0a..38fbbf313c2b 100644 --- a/src/transformers/utils/dummy_torchvision_objects.py +++ b/src/transformers/utils/dummy_torchvision_objects.py @@ -9,14 +9,14 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torchvision"]) -class ViTImageProcessorFast(metaclass=DummyObject): +class IJepaImageProcessorFast(metaclass=DummyObject): _backends = ["torchvision"] def __init__(self, *args, **kwargs): requires_backends(self, ["torchvision"]) -class IJepaImageProcessorFast(metaclass=DummyObject): +class ViTImageProcessorFast(metaclass=DummyObject): _backends = ["torchvision"] def __init__(self, *args, **kwargs): diff --git a/src/transformers/utils/dummy_vision_objects.py b/src/transformers/utils/dummy_vision_objects.py index 39313bd60c62..ffbae4164329 100644 --- a/src/transformers/utils/dummy_vision_objects.py +++ b/src/transformers/utils/dummy_vision_objects.py @@ -296,6 +296,13 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["vision"]) +class IJepaImageProcessor(metaclass=DummyObject): + _backends = ["vision"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["vision"]) + + class ImageGPTFeatureExtractor(metaclass=DummyObject): _backends = ["vision"] @@ -632,13 +639,6 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["vision"]) -class IJepaImageProcessor(metaclass=DummyObject): - _backends = ["vision"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["vision"]) - - class VitMatteImageProcessor(metaclass=DummyObject): _backends = ["vision"] diff --git a/tests/models/ijepa/test_modeling_ijepa.py b/tests/models/ijepa/test_modeling_ijepa.py index ed5f448b17b1..a74fd77f18ca 100644 --- a/tests/models/ijepa/test_modeling_ijepa.py +++ b/tests/models/ijepa/test_modeling_ijepa.py @@ -36,6 +36,7 @@ from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor from ...test_pipeline_mixin import PipelineTesterMixin + if is_torch_available(): import torch from torch import nn @@ -112,9 +113,7 @@ def prepare_config_and_inputs(self): labels = None if self.use_labels: - labels = ids_tensor( - [self.batch_size], self.type_sequence_label_size - ) + labels = ids_tensor([self.batch_size], self.type_sequence_label_size) config = self.get_config() @@ -148,9 +147,7 @@ def create_and_check_model(self, config, pixel_values, labels): (self.batch_size, self.seq_length, self.hidden_size), ) - def create_and_check_for_image_classification( - self, config, pixel_values, labels - ): + def create_and_check_for_image_classification(self, config, pixel_values, labels): config.num_labels = self.type_sequence_label_size model = IJepaForImageClassification(config) model.to(torch_device) @@ -167,9 +164,7 @@ def create_and_check_for_image_classification( model.to(torch_device) model.eval() - pixel_values = floats_tensor( - [self.batch_size, 1, self.image_size, self.image_size] - ) + pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size]) result = model(pixel_values) self.parent.assertEqual( result.logits.shape, @@ -246,15 +241,11 @@ def test_model(self): def test_for_masked_image_modeling(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_for_masked_image_modeling( - *config_and_inputs - ) + self.model_tester.create_and_check_for_masked_image_modeling(*config_and_inputs) def test_for_image_classification(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_for_image_classification( - *config_and_inputs - ) + self.model_tester.create_and_check_for_image_classification(*config_and_inputs) @slow def test_model_from_pretrained(self): @@ -274,25 +265,15 @@ def prepare_img(): class IJepaModelIntegrationTest(unittest.TestCase): @cached_property def default_image_processor(self): - return ( - IJepaImageProcessor.from_pretrained( - "google/ijepa-base-patch16-224" - ) - if is_vision_available() - else None - ) + return IJepaImageProcessor.from_pretrained("google/ijepa-base-patch16-224") if is_vision_available() else None @slow def test_inference_image_classification_head(self): - model = IJepaForImageClassification.from_pretrained( - "google/ijepa-base-patch16-224" - ).to(torch_device) + model = IJepaForImageClassification.from_pretrained("google/ijepa-base-patch16-224").to(torch_device) image_processor = self.default_image_processor image = prepare_img() - inputs = image_processor(images=image, return_tensors="pt").to( - torch_device - ) + inputs = image_processor(images=image, return_tensors="pt").to(torch_device) # forward pass with torch.no_grad(): @@ -302,13 +283,9 @@ def test_inference_image_classification_head(self): expected_shape = torch.Size((1, 1000)) self.assertEqual(outputs.logits.shape, expected_shape) - expected_slice = torch.tensor([-0.2744, 0.8215, -0.0836]).to( - torch_device - ) + expected_slice = torch.tensor([-0.2744, 0.8215, -0.0836]).to(torch_device) - self.assertTrue( - torch.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4) - ) + self.assertTrue(torch.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4)) @slow def test_inference_interpolate_pos_encoding(self): @@ -316,13 +293,9 @@ def test_inference_interpolate_pos_encoding(self): # allowing to interpolate the pre-trained position embeddings in order to use # the model on higher resolutions. The DINO model by Facebook AI leverages this # to visualize self-attention on higher resolution images. - model = IJepaModel.from_pretrained("facebook/dino-ijepas8").to( - torch_device - ) + model = IJepaModel.from_pretrained("facebook/dino-ijepas8").to(torch_device) - image_processor = IJepaImageProcessor.from_pretrained( - "facebook/dino-ijepas8", size=480 - ) + image_processor = IJepaImageProcessor.from_pretrained("facebook/dino-ijepas8", size=480) image = prepare_img() inputs = image_processor(images=image, return_tensors="pt") pixel_values = inputs.pixel_values.to(torch_device) @@ -343,11 +316,7 @@ def test_inference_interpolate_pos_encoding(self): ] ).to(torch_device) - self.assertTrue( - torch.allclose( - outputs.last_hidden_state[0, :3, :3], expected_slice, atol=1e-4 - ) - ) + self.assertTrue(torch.allclose(outputs.last_hidden_state[0, :3, :3], expected_slice, atol=1e-4)) @slow @require_accelerate From 7af89611cf991b38c92f8c4af36ba27991fb77e8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Marcelo=20Tozato?= Date: Thu, 22 Aug 2024 16:31:54 +0200 Subject: [PATCH 03/44] fix copy-from for IJepa model --- .../models/ijepa/modeling_ijepa.py | 88 ++++++------------- 1 file changed, 25 insertions(+), 63 deletions(-) diff --git a/src/transformers/models/ijepa/modeling_ijepa.py b/src/transformers/models/ijepa/modeling_ijepa.py index ec9ec009d7e6..d250d5b170f1 100644 --- a/src/transformers/models/ijepa/modeling_ijepa.py +++ b/src/transformers/models/ijepa/modeling_ijepa.py @@ -140,7 +140,7 @@ def forward( return embeddings -# Copied from transformers.models.vit.modeling_vit.ViTPatchEmbeddings with ViT->IJEPA +# Copied from transformers.models.vit.modeling_vit.ViTPatchEmbeddings with ViT->IJepa class IJepaPatchEmbeddings(nn.Module): """ This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial @@ -161,18 +161,9 @@ def __init__(self, config): self.num_channels = num_channels self.num_patches = num_patches - self.projection = nn.Conv2d( - num_channels, - hidden_size, - kernel_size=patch_size, - stride=patch_size, - ) + self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size) - def forward( - self, - pixel_values: torch.Tensor, - interpolate_pos_encoding: bool = False, - ) -> torch.Tensor: + def forward(self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False) -> torch.Tensor: batch_size, num_channels, height, width = pixel_values.shape if num_channels != self.num_channels: raise ValueError( @@ -189,7 +180,7 @@ def forward( return embeddings -# Copied from transformers.models.vit.modeling_vit.ViTSelfAttention with ViT->IJEPA +# Copied from transformers.models.vit.modeling_vit.ViTSelfAttention with ViT->IJepa class IJepaSelfAttention(nn.Module): def __init__(self, config: IJepaConfig) -> None: super().__init__() @@ -210,18 +201,12 @@ def __init__(self, config: IJepaConfig) -> None: self.dropout = nn.Dropout(config.attention_probs_dropout_prob) def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: - new_x_shape = x.size()[:-1] + ( - self.num_attention_heads, - self.attention_head_size, - ) + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) x = x.view(new_x_shape) return x.permute(0, 2, 1, 3) def forward( - self, - hidden_states, - head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, + self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: mixed_query_layer = self.query(hidden_states) @@ -256,17 +241,14 @@ def forward( return outputs -# Copied from transformers.models.vit.modeling_vit.ViTSdpaSelfAttention with ViT->IJEPA +# Copied from transformers.models.vit.modeling_vit.ViTSdpaSelfAttention with ViT->IJepa class IJepaSdpaSelfAttention(IJepaSelfAttention): def __init__(self, config: IJepaConfig) -> None: super().__init__(config) self.attention_probs_dropout_prob = config.attention_probs_dropout_prob def forward( - self, - hidden_states, - head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, + self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: mixed_query_layer = self.query(hidden_states) @@ -291,7 +273,7 @@ def forward( return context_layer, None -# Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->IJEPA +# Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->IJepa class IJepaSelfOutput(nn.Module): """ The residual connection is defined in IJepaLayer instead of here (as is the case with other models), due to the @@ -310,7 +292,7 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to return hidden_states -# Copied from transformers.models.vit.modeling_vit.ViTAttention with ViT->IJEPA +# Copied from transformers.models.vit.modeling_vit.ViTAttention with ViT->IJepa class IJepaAttention(nn.Module): def __init__(self, config: IJepaConfig) -> None: super().__init__() @@ -322,10 +304,7 @@ def prune_heads(self, heads: Set[int]) -> None: if len(heads) == 0: return heads, index = find_pruneable_heads_and_indices( - heads, - self.attention.num_attention_heads, - self.attention.attention_head_size, - self.pruned_heads, + heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads ) # Prune linear layers @@ -353,14 +332,14 @@ def forward( return outputs -# Copied from transformers.models.vit.modeling_vit.ViTSdpaAttention with ViT->IJEPA +# Copied from transformers.models.vit.modeling_vit.ViTSdpaAttention with ViT->IJepa class IJepaSdpaAttention(IJepaAttention): def __init__(self, config: IJepaConfig) -> None: super().__init__(config) self.attention = IJepaSdpaSelfAttention(config) -# Copied from transformers.models.vit.modeling_vit.ViTIntermediate with ViT->IJEPA +# Copied from transformers.models.vit.modeling_vit.ViTIntermediate with ViT->IJepa class IJepaIntermediate(nn.Module): def __init__(self, config: IJepaConfig) -> None: super().__init__() @@ -377,7 +356,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return hidden_states -# Copied from transformers.models.vit.modeling_vit.ViTOutput with ViT->IJEPA +# Copied from transformers.models.vit.modeling_vit.ViTOutput with ViT->IJepa class IJepaOutput(nn.Module): def __init__(self, config: IJepaConfig) -> None: super().__init__() @@ -399,7 +378,7 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to } -# Copied from transformers.models.vit.modeling_vit.ViTLayer with VIT->IJEPA,ViT->IJEPA +# Copied from transformers.models.vit.modeling_vit.ViTLayer with VIT->IJEPA,ViT->IJepa class IJepaLayer(nn.Module): """This corresponds to the Block class in the timm implementation.""" @@ -420,7 +399,7 @@ def forward( output_attentions: bool = False, ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: self_attention_outputs = self.attention( - self.layernorm_before(hidden_states), # in IJEPA, layernorm is applied before self-attention + self.layernorm_before(hidden_states), # in IJepa, layernorm is applied before self-attention head_mask, output_attentions=output_attentions, ) @@ -430,7 +409,7 @@ def forward( # first residual connection hidden_states = attention_output + hidden_states - # in IJEPA, layernorm is also applied after self-attention + # in IJepa, layernorm is also applied after self-attention layer_output = self.layernorm_after(hidden_states) layer_output = self.intermediate(layer_output) @@ -442,7 +421,7 @@ def forward( return outputs -# Copied from transformers.models.vit.modeling_vit.ViTEncoder with ViT->IJEPA +# Copied from transformers.models.vit.modeling_vit.ViTEncoder with ViT->IJepa class IJepaEncoder(nn.Module): def __init__(self, config: IJepaConfig) -> None: super().__init__() @@ -486,15 +465,7 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) if not return_dict: - return tuple( - v - for v in [ - hidden_states, - all_hidden_states, - all_self_attentions, - ] - if v is not None - ) + return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) return BaseModelOutput( last_hidden_state=hidden_states, hidden_states=all_hidden_states, @@ -522,9 +493,7 @@ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> No # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid # `trunc_normal_cpu` not implemented in `half` issues module.weight.data = nn.init.trunc_normal_( - module.weight.data.to(torch.float32), - mean=0.0, - std=self.config.initializer_range, + module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range ).to(module.weight.dtype) if module.bias is not None: module.bias.data.zero_() @@ -585,14 +554,9 @@ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> No "The bare I-JEPA Model transformer outputting raw hidden-states without any specific head on top.", IJEPA_START_DOCSTRING, ) -# Copied from transformers.models.vit.modeling_vit.ViTModel with VIT->IJEPA,ViT->IJEPA +# Copied from transformers.models.vit.modeling_vit.ViTModel with VIT->IJEPA,ViT->IJepa class IJepaModel(IJepaPreTrainedModel): - def __init__( - self, - config: IJepaConfig, - add_pooling_layer: bool = True, - use_mask_token: bool = False, - ): + def __init__(self, config: IJepaConfig, add_pooling_layer: bool = True, use_mask_token: bool = False): super().__init__(config) self.config = config @@ -660,9 +624,7 @@ def forward( pixel_values = pixel_values.to(expected_dtype) embedding_output = self.embeddings( - pixel_values, - bool_masked_pos=bool_masked_pos, - interpolate_pos_encoding=interpolate_pos_encoding, + pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding ) encoder_outputs = self.encoder( @@ -688,7 +650,7 @@ def forward( ) -# Copied from transformers.models.vit.modeling_vit.ViTPooler with ViT->IJEPA +# Copied from transformers.models.vit.modeling_vit.ViTPooler with ViT->IJepa class IJepaPooler(nn.Module): def __init__(self, config: IJepaConfig): super().__init__() @@ -719,7 +681,7 @@ def forward(self, hidden_states): """, IJEPA_START_DOCSTRING, ) -# Copied from transformers.models.vit.modeling_vit.ViTForImageClassification with VIT->IJEPA,ViT->IJEPA,vit->ijepa +# Copied from transformers.models.vit.modeling_vit.ViTForImageClassification with VIT->IJEPA,ViT->IJepa,vit->ijepa class IJepaForImageClassification(IJepaPreTrainedModel): def __init__(self, config: IJepaConfig) -> None: super().__init__(config) From a4c8eece49189ec2889030b2514fb65ffadcec04 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Marcelo=20Tozato?= Date: Sun, 25 Aug 2024 16:37:26 +0200 Subject: [PATCH 04/44] add weight conversion script --- .../models/ijepa/convert_dino_to_pytorch.py | 225 +++++++++--------- .../models/ijepa/modeling_ijepa.py | 49 +--- 2 files changed, 119 insertions(+), 155 deletions(-) diff --git a/src/transformers/models/ijepa/convert_dino_to_pytorch.py b/src/transformers/models/ijepa/convert_dino_to_pytorch.py index e9da790659ba..e4a493a03b57 100644 --- a/src/transformers/models/ijepa/convert_dino_to_pytorch.py +++ b/src/transformers/models/ijepa/convert_dino_to_pytorch.py @@ -12,10 +12,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Convert IJEPA checkpoints trained with the DINO method.""" +"""Convert IJEPA checkpoints from the original repository. + +URL: https://github.com/facebookresearch/ijepa +""" import argparse -import json from pathlib import Path import requests @@ -25,8 +27,7 @@ from transformers import ( IJepaConfig, - IJepaForImageClassification, - IJepaImageProcessor, + ViTImageProcessor, IJepaModel, ) from transformers.utils import logging @@ -37,8 +38,14 @@ # here we list all keys to be renamed (original name on the left, our name on the right) -def create_rename_keys(config, base_model=False): +def create_rename_keys(config): rename_keys = [] + + # projection layer + position embeddings + rename_keys.append(("pos_embed", "ijepa.embeddings.position_embeddings")) + rename_keys.append(("patch_embed.proj.weight", "ijepa.embeddings.patch_embeddings.projection.weight")) + rename_keys.append(("patch_embed.proj.bias", "ijepa.embeddings.patch_embeddings.projection.bias")) + for i in range(config.num_hidden_layers): # encoder layers: output projection, 2 feedforward neural networks and 2 layernorms rename_keys.append( @@ -102,72 +109,48 @@ def create_rename_keys(config, base_model=False): ) ) - # projection layer + position embeddings + # layernorm + pooler rename_keys.extend( [ - ("cls_token", "ijepa.embeddings.cls_token"), - ( - "patch_embed.proj.weight", - "ijepa.embeddings.patch_embeddings.projection.weight", - ), - ( - "patch_embed.proj.bias", - "ijepa.embeddings.patch_embeddings.projection.bias", - ), - ("pos_embed", "ijepa.embeddings.position_embeddings"), + ("norm.weight", "layernorm.weight"), + ("norm.bias", "layernorm.bias"), ] ) - if base_model: - # layernorm + pooler - rename_keys.extend( - [ - ("norm.weight", "layernorm.weight"), - ("norm.bias", "layernorm.bias"), - ] - ) - - # if just the base model, we should remove "ijepa" from all keys that start with "ijepa" - rename_keys = [(pair[0], pair[1][4:]) if pair[1].startswith("ijepa") else pair for pair in rename_keys] - else: - # layernorm + classification head - rename_keys.extend( - [ - ("norm.weight", "ijepa.layernorm.weight"), - ("norm.bias", "ijepa.layernorm.bias"), - ("head.weight", "classifier.weight"), - ("head.bias", "classifier.bias"), - ] - ) + # if just the base model, we should remove "ijepa" from all keys that start with "ijepa" + rename_keys = [ + (pair[0], pair[1][6:]) if pair[1].startswith("ijepa") else pair + for pair in rename_keys + ] return rename_keys # we split up the matrix of each encoder layer into queries, keys and values -def read_in_q_k_v(state_dict, config, base_model=False): +def read_in_q_k_v(state_dict, config): for i in range(config.num_hidden_layers): - if base_model: - prefix = "" - else: - prefix = "ijepa." # read in weights + bias of input projection layer (in timm, this is a single matrix + bias) in_proj_weight = state_dict.pop(f"blocks.{i}.attn.qkv.weight") in_proj_bias = state_dict.pop(f"blocks.{i}.attn.qkv.bias") # next, add query, keys and values (in that order) to the state dict - state_dict[f"{prefix}encoder.layer.{i}.attention.attention.query.weight"] = in_proj_weight[ - : config.hidden_size, : - ] - state_dict[f"{prefix}encoder.layer.{i}.attention.attention.query.bias"] = in_proj_bias[: config.hidden_size] - state_dict[f"{prefix}encoder.layer.{i}.attention.attention.key.weight"] = in_proj_weight[ - config.hidden_size : config.hidden_size * 2, : - ] - state_dict[f"{prefix}encoder.layer.{i}.attention.attention.key.bias"] = in_proj_bias[ - config.hidden_size : config.hidden_size * 2 - ] - state_dict[f"{prefix}encoder.layer.{i}.attention.attention.value.weight"] = in_proj_weight[ - -config.hidden_size :, : - ] - state_dict[f"{prefix}encoder.layer.{i}.attention.attention.value.bias"] = in_proj_bias[-config.hidden_size :] + state_dict[f"encoder.layer.{i}.attention.attention.query.weight"] = ( + in_proj_weight[: config.hidden_size, :] + ) + state_dict[f"encoder.layer.{i}.attention.attention.query.bias"] = ( + in_proj_bias[: config.hidden_size] + ) + state_dict[f"encoder.layer.{i}.attention.attention.key.weight"] = ( + in_proj_weight[config.hidden_size : config.hidden_size * 2, :] + ) + state_dict[f"encoder.layer.{i}.attention.attention.key.bias"] = ( + in_proj_bias[config.hidden_size : config.hidden_size * 2] + ) + state_dict[f"encoder.layer.{i}.attention.attention.value.weight"] = ( + in_proj_weight[-config.hidden_size :, :] + ) + state_dict[f"encoder.layer.{i}.attention.attention.value.bias"] = ( + in_proj_bias[-config.hidden_size :] + ) def remove_classification_head_(state_dict): @@ -188,76 +171,83 @@ def prepare_img(): return im +def get_ijepa_config(model_name): + patch_size = int(model_name.split("_")[1][4:]) + config = IJepaConfig(patch_size=patch_size) + if "vith" in model_name: + config.hidden_size = 1280 + config.num_hidden_layers = 32 + config.num_attention_heads = 16 + config.layer_norm_eps = 1e-6 + config.mlp_ratio = 4 + config.intermediate_size = 5120 + elif "vitg" in model_name: + config.hidden_size = 1408 + config.num_hidden_layers = 40 + config.num_attention_heads = 16 + config.layer_norm_eps = 1e-6 + config.mlp_ratio = 48 / 11 + config.intermediate_size = 6144 + else: + raise ValueError("Model not supported, only supports huge and giant models.") + return config + + @torch.no_grad() -def convert_ijepa_checkpoint(model_name, pytorch_dump_folder_path, base_model=True): +def convert_ijepa_checkpoint(model_name, pytorch_dump_folder_path): """ Copy/paste/tweak model's weights to our IJEPA structure. """ # define default IJEPA configuration - config = IJepaConfig() - # patch_size - if model_name[-1] == "8": - config.patch_size = 8 - # set labels if required - if not base_model: - config.num_labels = 1000 - repo_id = "huggingface/label-files" - filename = "imagenet-1k-id2label.json" - id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r")) - id2label = {int(k): v for k, v in id2label.items()} - config.id2label = id2label - config.label2id = {v: k for k, v in id2label.items()} - # size of the architecture - if model_name in ["dino_ijepas8", "dino_ijepas16"]: - config.hidden_size = 384 - config.intermediate_size = 1536 - config.num_hidden_layers = 12 - config.num_attention_heads = 6 - - # load original model from torch hub - original_model = torch.hub.load("facebookresearch/dino:main", model_name) - original_model.eval() - - # load state_dict of original model, remove and rename some keys - state_dict = original_model.state_dict() - if base_model: - remove_classification_head_(state_dict) - rename_keys = create_rename_keys(config, base_model=base_model) + config = get_ijepa_config(model_name) + + checkpoint_mapping = { + "ijepa_vith14_1k": "https://dl.fbaipublicfiles.com/ijepa/IN1K-vit.h.14-300e.pth.tar", + "ijepa_vith14_22k": "https://dl.fbaipublicfiles.com/ijepa/IN22K-vit.h.14-900e.pth.tar", + "ijepa_vith16_1k": "https://dl.fbaipublicfiles.com/ijepa/IN1K-vit.h.16-448px-300e.pth.tar", + "ijepa_vitg16_22k": "https://dl.fbaipublicfiles.com/ijepa/IN22K-vit.g.16-600e.pth.tar", + } + + # Load original checkpoint + checkpoint_url = checkpoint_mapping[model_name] + original_state_dict = torch.hub.load_state_dict_from_url(checkpoint_url, map_location="cpu")["encoder"] + original_state_dict = {k.replace("module.", ""): v for k, v in original_state_dict.items()} + + # Rename keys + state_dict = original_state_dict.copy() + remove_classification_head_(state_dict) + rename_keys = create_rename_keys(config) for src, dest in rename_keys: rename_key(state_dict, src, dest) - read_in_q_k_v(state_dict, config, base_model) + read_in_q_k_v(state_dict, config) # load HuggingFace model - if base_model: - model = IJepaModel(config, add_pooling_layer=False).eval() - else: - model = IJepaForImageClassification(config).eval() + model = IJepaModel(config, add_pooling_layer=False).eval() model.load_state_dict(state_dict) # Check outputs on an image, prepared by IJepaImageProcessor - image_processor = IJepaImageProcessor() + image_processor = ViTImageProcessor() encoding = image_processor(images=prepare_img(), return_tensors="pt") pixel_values = encoding["pixel_values"] outputs = model(pixel_values) - if base_model: - final_hidden_state_cls_token = original_model(pixel_values) - assert torch.allclose( - final_hidden_state_cls_token, - outputs.last_hidden_state[:, 0, :], - atol=1e-1, - ) - else: - logits = original_model(pixel_values) - assert logits.shape == outputs.logits.shape - assert torch.allclose(logits, outputs.logits, atol=1e-3) + expected_slice = torch.Tensor([[-0.0621, -0.0054, -2.7513], + [-0.1952, 0.0909, -3.9536], + [0.0942, -0.0331, -1.2833]]) - Path(pytorch_dump_folder_path).mkdir(exist_ok=True) - print(f"Saving model {model_name} to {pytorch_dump_folder_path}") - model.save_pretrained(pytorch_dump_folder_path) - print(f"Saving image processor to {pytorch_dump_folder_path}") - image_processor.save_pretrained(pytorch_dump_folder_path) + assert torch.allclose( + expected_slice, + outputs.last_hidden_state[0, :3, :3], + atol=1e-4, + ) + + if pytorch_dump_folder_path is not None: + Path(pytorch_dump_folder_path).mkdir(exist_ok=True) + print(f"Saving model {model_name} to {pytorch_dump_folder_path}") + model.save_pretrained(pytorch_dump_folder_path) + print(f"Saving image processor to {pytorch_dump_folder_path}") + image_processor.save_pretrained(pytorch_dump_folder_path) if __name__ == "__main__": @@ -265,9 +255,15 @@ def convert_ijepa_checkpoint(model_name, pytorch_dump_folder_path, base_model=Tr # Required parameters parser.add_argument( "--model_name", - default="dino_ijepab16", + default="ijepa_vith14_1k", type=str, - help="Name of the model trained with DINO you'd like to convert.", + choices=[ + "ijepa_vith14_1k", + "ijepa_vith14_22k", + "ijepa_vith16_1k", + "ijepa_vitg16_22k", + ], + help="Name of the model you'd like to convert.", ) parser.add_argument( "--pytorch_dump_folder_path", @@ -275,12 +271,7 @@ def convert_ijepa_checkpoint(model_name, pytorch_dump_folder_path, base_model=Tr type=str, help="Path to the output PyTorch model directory.", ) - parser.add_argument( - "--base_model", - action="store_true", - help="Whether to only convert the base model (no projection head weights).", - ) - parser.set_defaults(base_model=True) + parser.set_defaults() args = parser.parse_args() - convert_ijepa_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.base_model) + convert_ijepa_checkpoint(args.model_name, args.pytorch_dump_folder_path) diff --git a/src/transformers/models/ijepa/modeling_ijepa.py b/src/transformers/models/ijepa/modeling_ijepa.py index d250d5b170f1..1636bf01d59f 100644 --- a/src/transformers/models/ijepa/modeling_ijepa.py +++ b/src/transformers/models/ijepa/modeling_ijepa.py @@ -68,7 +68,7 @@ def __init__(self, config: IJepaConfig, use_mask_token: bool = False) -> None: self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) if use_mask_token else None self.patch_embeddings = IJepaPatchEmbeddings(config) num_patches = self.patch_embeddings.num_patches - self.position_embeddings = nn.Parameter(torch.randn(1, num_patches + 1, config.hidden_size)) + self.position_embeddings = nn.Parameter(torch.randn(1, num_patches, config.hidden_size)) self.dropout = nn.Dropout(config.hidden_dropout_prob) self.config = config @@ -117,10 +117,9 @@ def forward( self, pixel_values: torch.Tensor, bool_masked_pos: Optional[torch.BoolTensor] = None, - interpolate_pos_encoding: bool = False, ) -> torch.Tensor: batch_size, _, height, width = pixel_values.shape - embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding) + embeddings = self.patch_embeddings(pixel_values) if bool_masked_pos is not None: seq_length = embeddings.shape[1] @@ -130,17 +129,13 @@ def forward( embeddings = embeddings * (1.0 - mask) + mask_tokens * mask # add positional encoding to each token - if interpolate_pos_encoding: - embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width) - else: - embeddings = embeddings + self.position_embeddings + embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width) embeddings = self.dropout(embeddings) return embeddings -# Copied from transformers.models.vit.modeling_vit.ViTPatchEmbeddings with ViT->IJepa class IJepaPatchEmbeddings(nn.Module): """ This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial @@ -163,19 +158,18 @@ def __init__(self, config): self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size) - def forward(self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False) -> torch.Tensor: + def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: batch_size, num_channels, height, width = pixel_values.shape if num_channels != self.num_channels: raise ValueError( "Make sure that the channel dimension of the pixel values match with the one set in the configuration." f" Expected {self.num_channels} but got {num_channels}." ) - if not interpolate_pos_encoding: - if height != self.image_size[0] or width != self.image_size[1]: - raise ValueError( - f"Input image size ({height}*{width}) doesn't match model" - f" ({self.image_size[0]}*{self.image_size[1]})." - ) + if height != self.image_size[0] or width != self.image_size[1]: + raise ValueError( + f"Input image size ({height}*{width}) doesn't match model" + f" ({self.image_size[0]}*{self.image_size[1]})." + ) embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2) return embeddings @@ -378,7 +372,7 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to } -# Copied from transformers.models.vit.modeling_vit.ViTLayer with VIT->IJEPA,ViT->IJepa +# Copied from transformers.models.vit.modeling_vit.ViTLayer with VIT->IJepa,ViT->IJepa class IJepaLayer(nn.Module): """This corresponds to the Block class in the timm implementation.""" @@ -473,7 +467,6 @@ def forward( ) -# Copied from transformers.models.vit.modeling_vit.ViTPreTrainedModel with ViT->IJepa,vit->ijepa class IJepaPreTrainedModel(PreTrainedModel): """ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained @@ -507,12 +500,6 @@ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> No std=self.config.initializer_range, ).to(module.position_embeddings.dtype) - module.cls_token.data = nn.init.trunc_normal_( - module.cls_token.data.to(torch.float32), - mean=0.0, - std=self.config.initializer_range, - ).to(module.cls_token.dtype) - IJEPA_START_DOCSTRING = r""" This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it @@ -543,8 +530,6 @@ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> No output_hidden_states (`bool`, *optional*): Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more detail. - interpolate_pos_encoding (`bool`, *optional*): - Whether to interpolate the pre-trained position encodings. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. """ @@ -554,7 +539,6 @@ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> No "The bare I-JEPA Model transformer outputting raw hidden-states without any specific head on top.", IJEPA_START_DOCSTRING, ) -# Copied from transformers.models.vit.modeling_vit.ViTModel with VIT->IJEPA,ViT->IJepa class IJepaModel(IJepaPreTrainedModel): def __init__(self, config: IJepaConfig, add_pooling_layer: bool = True, use_mask_token: bool = False): super().__init__(config) @@ -595,7 +579,6 @@ def forward( head_mask: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - interpolate_pos_encoding: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple, BaseModelOutputWithPooling]: r""" @@ -624,7 +607,7 @@ def forward( pixel_values = pixel_values.to(expected_dtype) embedding_output = self.embeddings( - pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding + pixel_values, bool_masked_pos=bool_masked_pos, ) encoder_outputs = self.encoder( @@ -671,17 +654,9 @@ def forward(self, hidden_states): I-JEPA Model transformer with an image classification head on top (a linear layer on top of the final hidden state of the average pooling of the output) e.g. for ImageNet. - - - Note that it's possible to fine-tune I-JEPA on higher resolution images than the ones it has been trained on, by - setting `interpolate_pos_encoding` to `True` in the forward of the model. This will interpolate the pre-trained - position embeddings to the higher resolution. - - """, IJEPA_START_DOCSTRING, ) -# Copied from transformers.models.vit.modeling_vit.ViTForImageClassification with VIT->IJEPA,ViT->IJepa,vit->ijepa class IJepaForImageClassification(IJepaPreTrainedModel): def __init__(self, config: IJepaConfig) -> None: super().__init__(config) @@ -709,7 +684,6 @@ def forward( labels: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - interpolate_pos_encoding: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[tuple, ImageClassifierOutput]: r""" @@ -725,7 +699,6 @@ def forward( head_mask=head_mask, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - interpolate_pos_encoding=interpolate_pos_encoding, return_dict=return_dict, ) From bf70f98a1356e4b885600eec4cd821898c99c6b4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Marcelo=20Tozato?= Date: Sun, 25 Aug 2024 16:40:29 +0200 Subject: [PATCH 05/44] update attention class names in IJepa model --- src/transformers/models/ijepa/modeling_ijepa.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/ijepa/modeling_ijepa.py b/src/transformers/models/ijepa/modeling_ijepa.py index 1636bf01d59f..7ddb1c4cdd9d 100644 --- a/src/transformers/models/ijepa/modeling_ijepa.py +++ b/src/transformers/models/ijepa/modeling_ijepa.py @@ -366,7 +366,7 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to return hidden_states -IJEPA_ATTENTION_CLASSES = { +IJepa_ATTENTION_CLASSES = { "eager": IJepaAttention, "sdpa": IJepaSdpaAttention, } @@ -380,7 +380,7 @@ def __init__(self, config: IJepaConfig) -> None: super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward self.seq_len_dim = 1 - self.attention = IJEPA_ATTENTION_CLASSES[config._attn_implementation](config) + self.attention = IJepa_ATTENTION_CLASSES[config._attn_implementation](config) self.intermediate = IJepaIntermediate(config) self.output = IJepaOutput(config) self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) From 64f2208fbb0c5165f08cad6a1814162e01baa6f7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Marcelo=20Tozato?= Date: Sun, 25 Aug 2024 16:41:05 +0200 Subject: [PATCH 06/44] style changes --- .../models/ijepa/convert_dino_to_pytorch.py | 42 +++++++------------ .../models/ijepa/modeling_ijepa.py | 3 +- 2 files changed, 17 insertions(+), 28 deletions(-) diff --git a/src/transformers/models/ijepa/convert_dino_to_pytorch.py b/src/transformers/models/ijepa/convert_dino_to_pytorch.py index e4a493a03b57..1fb2a93ba0d7 100644 --- a/src/transformers/models/ijepa/convert_dino_to_pytorch.py +++ b/src/transformers/models/ijepa/convert_dino_to_pytorch.py @@ -22,13 +22,12 @@ import requests import torch -from huggingface_hub import hf_hub_download from PIL import Image from transformers import ( IJepaConfig, - ViTImageProcessor, IJepaModel, + ViTImageProcessor, ) from transformers.utils import logging @@ -118,10 +117,7 @@ def create_rename_keys(config): ) # if just the base model, we should remove "ijepa" from all keys that start with "ijepa" - rename_keys = [ - (pair[0], pair[1][6:]) if pair[1].startswith("ijepa") else pair - for pair in rename_keys - ] + rename_keys = [(pair[0], pair[1][6:]) if pair[1].startswith("ijepa") else pair for pair in rename_keys] return rename_keys @@ -133,24 +129,16 @@ def read_in_q_k_v(state_dict, config): in_proj_weight = state_dict.pop(f"blocks.{i}.attn.qkv.weight") in_proj_bias = state_dict.pop(f"blocks.{i}.attn.qkv.bias") # next, add query, keys and values (in that order) to the state dict - state_dict[f"encoder.layer.{i}.attention.attention.query.weight"] = ( - in_proj_weight[: config.hidden_size, :] - ) - state_dict[f"encoder.layer.{i}.attention.attention.query.bias"] = ( - in_proj_bias[: config.hidden_size] - ) - state_dict[f"encoder.layer.{i}.attention.attention.key.weight"] = ( - in_proj_weight[config.hidden_size : config.hidden_size * 2, :] - ) - state_dict[f"encoder.layer.{i}.attention.attention.key.bias"] = ( - in_proj_bias[config.hidden_size : config.hidden_size * 2] - ) - state_dict[f"encoder.layer.{i}.attention.attention.value.weight"] = ( - in_proj_weight[-config.hidden_size :, :] - ) - state_dict[f"encoder.layer.{i}.attention.attention.value.bias"] = ( - in_proj_bias[-config.hidden_size :] - ) + state_dict[f"encoder.layer.{i}.attention.attention.query.weight"] = in_proj_weight[: config.hidden_size, :] + state_dict[f"encoder.layer.{i}.attention.attention.query.bias"] = in_proj_bias[: config.hidden_size] + state_dict[f"encoder.layer.{i}.attention.attention.key.weight"] = in_proj_weight[ + config.hidden_size : config.hidden_size * 2, : + ] + state_dict[f"encoder.layer.{i}.attention.attention.key.bias"] = in_proj_bias[ + config.hidden_size : config.hidden_size * 2 + ] + state_dict[f"encoder.layer.{i}.attention.attention.value.weight"] = in_proj_weight[-config.hidden_size :, :] + state_dict[f"encoder.layer.{i}.attention.attention.value.bias"] = in_proj_bias[-config.hidden_size :] def remove_classification_head_(state_dict): @@ -232,9 +220,9 @@ def convert_ijepa_checkpoint(model_name, pytorch_dump_folder_path): pixel_values = encoding["pixel_values"] outputs = model(pixel_values) - expected_slice = torch.Tensor([[-0.0621, -0.0054, -2.7513], - [-0.1952, 0.0909, -3.9536], - [0.0942, -0.0331, -1.2833]]) + expected_slice = torch.Tensor( + [[-0.0621, -0.0054, -2.7513], [-0.1952, 0.0909, -3.9536], [0.0942, -0.0331, -1.2833]] + ) assert torch.allclose( expected_slice, diff --git a/src/transformers/models/ijepa/modeling_ijepa.py b/src/transformers/models/ijepa/modeling_ijepa.py index 7ddb1c4cdd9d..1395bb0d5f68 100644 --- a/src/transformers/models/ijepa/modeling_ijepa.py +++ b/src/transformers/models/ijepa/modeling_ijepa.py @@ -607,7 +607,8 @@ def forward( pixel_values = pixel_values.to(expected_dtype) embedding_output = self.embeddings( - pixel_values, bool_masked_pos=bool_masked_pos, + pixel_values, + bool_masked_pos=bool_masked_pos, ) encoder_outputs = self.encoder( From 1dd4e7dde3470433fe111bc41184e6c57d4efa29 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Marcelo=20Tozato?= Date: Sun, 25 Aug 2024 18:02:26 +0200 Subject: [PATCH 07/44] Add push_to_hub option to convert_ijepa_checkpoint function --- .../models/ijepa/convert_dino_to_pytorch.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/ijepa/convert_dino_to_pytorch.py b/src/transformers/models/ijepa/convert_dino_to_pytorch.py index 1fb2a93ba0d7..22ccfb1a0036 100644 --- a/src/transformers/models/ijepa/convert_dino_to_pytorch.py +++ b/src/transformers/models/ijepa/convert_dino_to_pytorch.py @@ -182,7 +182,7 @@ def get_ijepa_config(model_name): @torch.no_grad() -def convert_ijepa_checkpoint(model_name, pytorch_dump_folder_path): +def convert_ijepa_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub): """ Copy/paste/tweak model's weights to our IJEPA structure. """ @@ -237,6 +237,16 @@ def convert_ijepa_checkpoint(model_name, pytorch_dump_folder_path): print(f"Saving image processor to {pytorch_dump_folder_path}") image_processor.save_pretrained(pytorch_dump_folder_path) + if push_to_hub: + model_name_to_hf_name = { + "ijepa_vith14_1k": "ijepa_huge_patch14_1k", + "ijepa_vith14_22k": "ijepa_huge_patch14_22k", + "ijepa_vith16_1k": "ijepa_huge_patch16_1k", + "ijepa_vitg16_22k": "ijepa_giant_patch16_22k", + } + name = model_name_to_hf_name[model_name] + model.push_to_hub(f"jmtzt/{name}", use_temp_dir=True) + if __name__ == "__main__": parser = argparse.ArgumentParser() @@ -259,7 +269,12 @@ def convert_ijepa_checkpoint(model_name, pytorch_dump_folder_path): type=str, help="Path to the output PyTorch model directory.", ) + parser.add_argument( + "--push_to_hub", + action="store_true", + help="Whether or not to push the model to the Hub.", + ) parser.set_defaults() args = parser.parse_args() - convert_ijepa_checkpoint(args.model_name, args.pytorch_dump_folder_path) + convert_ijepa_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.push_to_hub) From 9826f9985fa4ff17d3819c0842595c45150200f7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Marcelo=20Tozato?= Date: Sun, 25 Aug 2024 18:03:00 +0200 Subject: [PATCH 08/44] add initial tests for I-JEPA --- tests/models/ijepa/test_modeling_ijepa.py | 62 +++++++---------------- 1 file changed, 17 insertions(+), 45 deletions(-) diff --git a/tests/models/ijepa/test_modeling_ijepa.py b/tests/models/ijepa/test_modeling_ijepa.py index a74fd77f18ca..371db004c62e 100644 --- a/tests/models/ijepa/test_modeling_ijepa.py +++ b/tests/models/ijepa/test_modeling_ijepa.py @@ -47,7 +47,7 @@ if is_vision_available(): from PIL import Image - from transformers import IJepaImageProcessor + from transformers import ViTImageProcessor class IJepaModelTester: @@ -94,9 +94,9 @@ def __init__( self.encoder_stride = encoder_stride self.attn_implementation = attn_implementation - # in IJEPA, the seq length equals the number of patches + 1 (we add 1 for the [CLS] token) + # in IJEPA, the seq length equals the number of patches (we don't add 1 for the [CLS] token) num_patches = (image_size // patch_size) ** 2 - self.seq_length = num_patches + 1 + self.seq_length = num_patches self.mask_ratio = mask_ratio self.num_masks = int(mask_ratio * self.seq_length) self.mask_length = num_patches @@ -197,6 +197,11 @@ class IJepaModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): if is_torch_available() else () ) + pipeline_model_mapping = ( + {"image-feature-extraction": IJepaModel, "image-classification": IJepaForImageClassification} + if is_torch_available() + else {} + ) fx_compatible = False test_pruning = False @@ -239,17 +244,13 @@ def test_model(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_model(*config_and_inputs) - def test_for_masked_image_modeling(self): - config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_for_masked_image_modeling(*config_and_inputs) - def test_for_image_classification(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_for_image_classification(*config_and_inputs) @slow def test_model_from_pretrained(self): - model_name = "google/ijepa-base-patch16-224" + model_name = "jmtzt/ijepa_huge_patch14_1k" model = IJepaModel.from_pretrained(model_name) self.assertIsNotNone(model) @@ -265,11 +266,11 @@ def prepare_img(): class IJepaModelIntegrationTest(unittest.TestCase): @cached_property def default_image_processor(self): - return IJepaImageProcessor.from_pretrained("google/ijepa-base-patch16-224") if is_vision_available() else None + return ViTImageProcessor.from_pretrained("google/vit-base-patch16-224") if is_vision_available() else None @slow - def test_inference_image_classification_head(self): - model = IJepaForImageClassification.from_pretrained("google/ijepa-base-patch16-224").to(torch_device) + def test_inference_no_head(self): + model = IJepaModel.from_pretrained("jmtzt/ijepa_huge_patch14_1k").to(torch_device) image_processor = self.default_image_processor image = prepare_img() @@ -279,41 +280,12 @@ def test_inference_image_classification_head(self): with torch.no_grad(): outputs = model(**inputs) - # verify the logits - expected_shape = torch.Size((1, 1000)) - self.assertEqual(outputs.logits.shape, expected_shape) - - expected_slice = torch.tensor([-0.2744, 0.8215, -0.0836]).to(torch_device) - - self.assertTrue(torch.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4)) - - @slow - def test_inference_interpolate_pos_encoding(self): - # IJEPA models have an `interpolate_pos_encoding` argument in their forward method, - # allowing to interpolate the pre-trained position embeddings in order to use - # the model on higher resolutions. The DINO model by Facebook AI leverages this - # to visualize self-attention on higher resolution images. - model = IJepaModel.from_pretrained("facebook/dino-ijepas8").to(torch_device) - - image_processor = IJepaImageProcessor.from_pretrained("facebook/dino-ijepas8", size=480) - image = prepare_img() - inputs = image_processor(images=image, return_tensors="pt") - pixel_values = inputs.pixel_values.to(torch_device) - - # forward pass - with torch.no_grad(): - outputs = model(pixel_values, interpolate_pos_encoding=True) - - # verify the logits - expected_shape = torch.Size((1, 3601, 384)) + # verify the last hidden state + expected_shape = torch.Size((1, 256, 1280)) self.assertEqual(outputs.last_hidden_state.shape, expected_shape) - expected_slice = torch.tensor( - [ - [4.2340, 4.3906, -6.6692], - [4.5463, 1.8928, -6.7257], - [4.4429, 0.8496, -5.8585], - ] + expected_slice = torch.Tensor( + [[-0.0621, -0.0054, -2.7513], [-0.1952, 0.0909, -3.9536], [0.0942, -0.0331, -1.2833]] ).to(torch_device) self.assertTrue(torch.allclose(outputs.last_hidden_state[0, :3, :3], expected_slice, atol=1e-4)) @@ -327,7 +299,7 @@ def test_inference_fp16(self): A small test to make sure that inference work in half precision without any problem. """ model = IJepaModel.from_pretrained( - "facebook/dino-ijepas8", + "jmtzt/ijepa_huge_patch14_1k", torch_dtype=torch.float16, device_map="auto", ) From d78e468ee8282283de7f85afbbf03c60a77790a9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Marcelo=20Tozato?= Date: Sun, 25 Aug 2024 18:04:06 +0200 Subject: [PATCH 09/44] minor style changes to conversion script --- .../models/ijepa/convert_dino_to_pytorch.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/transformers/models/ijepa/convert_dino_to_pytorch.py b/src/transformers/models/ijepa/convert_dino_to_pytorch.py index 22ccfb1a0036..81bf417a0c0c 100644 --- a/src/transformers/models/ijepa/convert_dino_to_pytorch.py +++ b/src/transformers/models/ijepa/convert_dino_to_pytorch.py @@ -239,11 +239,11 @@ def convert_ijepa_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub): if push_to_hub: model_name_to_hf_name = { - "ijepa_vith14_1k": "ijepa_huge_patch14_1k", - "ijepa_vith14_22k": "ijepa_huge_patch14_22k", - "ijepa_vith16_1k": "ijepa_huge_patch16_1k", - "ijepa_vitg16_22k": "ijepa_giant_patch16_22k", - } + "ijepa_vith14_1k": "ijepa_huge_patch14_1k", + "ijepa_vith14_22k": "ijepa_huge_patch14_22k", + "ijepa_vith16_1k": "ijepa_huge_patch16_1k", + "ijepa_vitg16_22k": "ijepa_giant_patch16_22k", + } name = model_name_to_hf_name[model_name] model.push_to_hub(f"jmtzt/{name}", use_temp_dir=True) From 7a64b83fff5bd6fd6017647dc498fbca6b2b8208 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Marcelo=20Tozato?= Date: Sun, 25 Aug 2024 19:41:20 +0200 Subject: [PATCH 10/44] make fixup related --- docs/source/en/index.md | 2 +- src/transformers/__init__.py | 27 -- .../models/auto/feature_extraction_auto.py | 1 - .../models/auto/image_processing_auto.py | 1 + src/transformers/models/ijepa/__init__.py | 27 +- .../models/ijepa/configuration_ijepa.py | 4 - .../models/ijepa/convert_dino_to_pytorch.py | 2 +- .../ijepa/convert_ijepa_timm_to_pytorch.py | 319 ------------------ .../models/ijepa/modeling_ijepa.py | 28 +- src/transformers/utils/dummy_flax_objects.py | 21 -- src/transformers/utils/dummy_tf_objects.py | 21 -- .../utils/dummy_torchvision_objects.py | 7 - .../utils/dummy_vision_objects.py | 7 - 13 files changed, 31 insertions(+), 436 deletions(-) delete mode 100644 src/transformers/models/ijepa/convert_ijepa_timm_to_pytorch.py diff --git a/docs/source/en/index.md b/docs/source/en/index.md index 5225ff328a5d..963f88410b4f 100644 --- a/docs/source/en/index.md +++ b/docs/source/en/index.md @@ -165,7 +165,7 @@ Flax), PyTorch, and/or TensorFlow. | [Hiera](model_doc/hiera) | ✅ | ❌ | ❌ | | [Hubert](model_doc/hubert) | ✅ | ✅ | ❌ | | [I-BERT](model_doc/ibert) | ✅ | ❌ | ❌ | -| [I-JEPA](model_doc/ijepa) | ✅ | ✅ | ✅ | +| [I-JEPA](model_doc/ijepa) | ✅ | ❌ | ❌ | | [IDEFICS](model_doc/idefics) | ✅ | ✅ | ❌ | | [Idefics2](model_doc/idefics2) | ✅ | ❌ | ❌ | | [ImageGPT](model_doc/imagegpt) | ✅ | ❌ | ❌ | diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 47ff11b3bf04..c589c8de134a 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -1218,7 +1218,6 @@ ] else: _import_structure["image_processing_utils_fast"] = ["BaseImageProcessorFast"] - _import_structure["models.ijepa"].append("IJepaImageProcessorFast") _import_structure["models.vit"].append("ViTImageProcessorFast") # PyTorch-backed objects @@ -4088,13 +4087,6 @@ ] ) - _import_structure["models.ijepa"].extend( - [ - "TFIJepaForImageClassification", - "TFIJepaModel", - "TFIJepaPreTrainedModel", - ] - ) _import_structure["models.layoutlm"].extend( [ "TFLayoutLMForMaskedLM", @@ -4755,13 +4747,6 @@ _import_structure["models.vision_encoder_decoder"].append("FlaxVisionEncoderDecoderModel") _import_structure["models.vision_text_dual_encoder"].extend(["FlaxVisionTextDualEncoderModel"]) _import_structure["models.vit"].extend(["FlaxViTForImageClassification", "FlaxViTModel", "FlaxViTPreTrainedModel"]) - _import_structure["models.ijepa"].extend( - [ - "FlaxIJepaForImageClassification", - "FlaxIJepaModel", - "FlaxIJepaPreTrainedModel", - ] - ) _import_structure["models.wav2vec2"].extend( [ "FlaxWav2Vec2ForCTC", @@ -5965,7 +5950,6 @@ from .models.grounding_dino import GroundingDinoImageProcessor from .models.idefics import IdeficsImageProcessor from .models.idefics2 import Idefics2ImageProcessor - from .models.ijepa import IJepaImageProcessor from .models.imagegpt import ImageGPTFeatureExtractor, ImageGPTImageProcessor from .models.instructblipvideo import InstructBlipVideoImageProcessor from .models.layoutlmv2 import ( @@ -6028,7 +6012,6 @@ from .utils.dummy_torchvision_objects import * else: from .image_processing_utils_fast import BaseImageProcessorFast - from .models.ijepa import IJepaImageProcessorFast from .models.vit import ViTImageProcessorFast # Modeling @@ -8338,11 +8321,6 @@ TFIdeficsModel, TFIdeficsPreTrainedModel, ) - from .models.ijepa import ( - TFIJepaForImageClassification, - TFIJepaModel, - TFIJepaPreTrainedModel, - ) from .models.layoutlm import ( TFLayoutLMForMaskedLM, TFLayoutLMForQuestionAnswering, @@ -8810,11 +8788,6 @@ FlaxGPTJModel, FlaxGPTJPreTrainedModel, ) - from .models.ijepa import ( - FlaxIJepaForImageClassification, - FlaxIJepaModel, - FlaxIJepaPreTrainedModel, - ) from .models.llama import ( FlaxLlamaForCausalLM, FlaxLlamaModel, diff --git a/src/transformers/models/auto/feature_extraction_auto.py b/src/transformers/models/auto/feature_extraction_auto.py index bba05458851d..7f335d66584f 100644 --- a/src/transformers/models/auto/feature_extraction_auto.py +++ b/src/transformers/models/auto/feature_extraction_auto.py @@ -63,7 +63,6 @@ ("glpn", "GLPNFeatureExtractor"), ("groupvit", "CLIPFeatureExtractor"), ("hubert", "Wav2Vec2FeatureExtractor"), - ("ijepa", "IJepaFeatureExtractor"), ("imagegpt", "ImageGPTFeatureExtractor"), ("layoutlmv2", "LayoutLMv2FeatureExtractor"), ("layoutlmv3", "LayoutLMv3FeatureExtractor"), diff --git a/src/transformers/models/auto/image_processing_auto.py b/src/transformers/models/auto/image_processing_auto.py index d072a1b3deb0..e01701076bb0 100644 --- a/src/transformers/models/auto/image_processing_auto.py +++ b/src/transformers/models/auto/image_processing_auto.py @@ -89,6 +89,7 @@ ("hiera", ("BitImageProcessor",)), ("idefics", ("IdeficsImageProcessor",)), ("idefics2", ("Idefics2ImageProcessor",)), + ("ijepa", ("ViTImageProcessor", "ViTImageProcessorFast")), ("imagegpt", ("ImageGPTImageProcessor",)), ("instructblip", ("BlipImageProcessor",)), ("instructblipvideo", ("InstructBlipVideoImageProcessor",)), diff --git a/src/transformers/models/ijepa/__init__.py b/src/transformers/models/ijepa/__init__.py index 52ee746d156b..0fdf4806c789 100644 --- a/src/transformers/models/ijepa/__init__.py +++ b/src/transformers/models/ijepa/__init__.py @@ -17,20 +17,11 @@ OptionalDependencyNotAvailable, _LazyModule, is_torch_available, - is_torchvision_available, ) _import_structure = {"configuration_ijepa": ["IJepaConfig", "IJepaOnnxConfig"]} -try: - if not is_torchvision_available(): - raise OptionalDependencyNotAvailable() -except OptionalDependencyNotAvailable: - pass -else: - _import_structure["image_processing_ijepa_fast"] = ["IJepaImageProcessorFast"] - try: if not is_torch_available(): raise OptionalDependencyNotAvailable() @@ -47,22 +38,16 @@ from .configuration_ijepa import IJepaConfig, IJepaOnnxConfig try: - if not is_torchvision_available(): + if not is_torch_available(): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: pass else: - try: - if not is_torch_available(): - raise OptionalDependencyNotAvailable() - except OptionalDependencyNotAvailable: - pass - else: - from .modeling_ijepa import ( - IJepaForImageClassification, - IJepaModel, - IJepaPreTrainedModel, - ) + from .modeling_ijepa import ( + IJepaForImageClassification, + IJepaModel, + IJepaPreTrainedModel, + ) else: import sys diff --git a/src/transformers/models/ijepa/configuration_ijepa.py b/src/transformers/models/ijepa/configuration_ijepa.py index 1bdcac4e25c3..a1d551db4bbd 100644 --- a/src/transformers/models/ijepa/configuration_ijepa.py +++ b/src/transformers/models/ijepa/configuration_ijepa.py @@ -66,8 +66,6 @@ class IJepaConfig(PretrainedConfig): The number of input channels. qkv_bias (`bool`, *optional*, defaults to `True`): Whether to add a bias to the queries, keys and values. - encoder_stride (`int`, *optional*, defaults to 16): - Factor to increase the spatial resolution by in the decoder head for masked image modeling. Example: @@ -101,7 +99,6 @@ def __init__( patch_size=16, num_channels=3, qkv_bias=True, - encoder_stride=16, **kwargs, ): super().__init__(**kwargs) @@ -119,7 +116,6 @@ def __init__( self.patch_size = patch_size self.num_channels = num_channels self.qkv_bias = qkv_bias - self.encoder_stride = encoder_stride class IJepaOnnxConfig(OnnxConfig): diff --git a/src/transformers/models/ijepa/convert_dino_to_pytorch.py b/src/transformers/models/ijepa/convert_dino_to_pytorch.py index 81bf417a0c0c..aeed304224c1 100644 --- a/src/transformers/models/ijepa/convert_dino_to_pytorch.py +++ b/src/transformers/models/ijepa/convert_dino_to_pytorch.py @@ -214,7 +214,7 @@ def convert_ijepa_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub): model = IJepaModel(config, add_pooling_layer=False).eval() model.load_state_dict(state_dict) - # Check outputs on an image, prepared by IJepaImageProcessor + # Check outputs on an image, prepared by ViTImageProcessor image_processor = ViTImageProcessor() encoding = image_processor(images=prepare_img(), return_tensors="pt") pixel_values = encoding["pixel_values"] diff --git a/src/transformers/models/ijepa/convert_ijepa_timm_to_pytorch.py b/src/transformers/models/ijepa/convert_ijepa_timm_to_pytorch.py deleted file mode 100644 index 2ff83b964c1b..000000000000 --- a/src/transformers/models/ijepa/convert_ijepa_timm_to_pytorch.py +++ /dev/null @@ -1,319 +0,0 @@ -# coding=utf-8 -# Copyright 2024 The HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Convert I-JEPA and non-distilled DeiT checkpoints from the timm library.""" - -import argparse -from pathlib import Path - -import requests -import timm -import torch -from PIL import Image -from timm.data import ImageNetInfo, infer_imagenet_subset - -from transformers import ( - DeiTImageProcessor, - IJepaConfig, - IJepaForImageClassification, - IJepaImageProcessor, - IJepaModel, -) -from transformers.utils import logging - - -logging.set_verbosity_info() -logger = logging.get_logger(__name__) - - -# here we list all keys to be renamed (original name on the left, our name on the right) -def create_rename_keys(config, base_model=False): - rename_keys = [] - for i in range(config.num_hidden_layers): - # encoder layers: output projection, 2 feedforward neural networks and 2 layernorms - rename_keys.append( - ( - f"blocks.{i}.norm1.weight", - f"ijepa.encoder.layer.{i}.layernorm_before.weight", - ) - ) - rename_keys.append( - ( - f"blocks.{i}.norm1.bias", - f"ijepa.encoder.layer.{i}.layernorm_before.bias", - ) - ) - rename_keys.append( - ( - f"blocks.{i}.attn.proj.weight", - f"ijepa.encoder.layer.{i}.attention.output.dense.weight", - ) - ) - rename_keys.append( - ( - f"blocks.{i}.attn.proj.bias", - f"ijepa.encoder.layer.{i}.attention.output.dense.bias", - ) - ) - rename_keys.append( - ( - f"blocks.{i}.norm2.weight", - f"ijepa.encoder.layer.{i}.layernorm_after.weight", - ) - ) - rename_keys.append( - ( - f"blocks.{i}.norm2.bias", - f"ijepa.encoder.layer.{i}.layernorm_after.bias", - ) - ) - rename_keys.append( - ( - f"blocks.{i}.mlp.fc1.weight", - f"ijepa.encoder.layer.{i}.intermediate.dense.weight", - ) - ) - rename_keys.append( - ( - f"blocks.{i}.mlp.fc1.bias", - f"ijepa.encoder.layer.{i}.intermediate.dense.bias", - ) - ) - rename_keys.append( - ( - f"blocks.{i}.mlp.fc2.weight", - f"ijepa.encoder.layer.{i}.output.dense.weight", - ) - ) - rename_keys.append( - ( - f"blocks.{i}.mlp.fc2.bias", - f"ijepa.encoder.layer.{i}.output.dense.bias", - ) - ) - - # projection layer + position embeddings - rename_keys.extend( - [ - ("cls_token", "ijepa.embeddings.cls_token"), - ( - "patch_embed.proj.weight", - "ijepa.embeddings.patch_embeddings.projection.weight", - ), - ( - "patch_embed.proj.bias", - "ijepa.embeddings.patch_embeddings.projection.bias", - ), - ("pos_embed", "ijepa.embeddings.position_embeddings"), - ] - ) - - if base_model: - # layernorm - rename_keys.extend( - [ - ("norm.weight", "layernorm.weight"), - ("norm.bias", "layernorm.bias"), - ] - ) - - # if just the base model, we should remove "ijepa" from all keys that start with "ijepa" - rename_keys = [(pair[0], pair[1][4:]) if pair[1].startswith("ijepa") else pair for pair in rename_keys] - else: - # layernorm + classification head - rename_keys.extend( - [ - ("norm.weight", "ijepa.layernorm.weight"), - ("norm.bias", "ijepa.layernorm.bias"), - ("head.weight", "classifier.weight"), - ("head.bias", "classifier.bias"), - ] - ) - - return rename_keys - - -# we split up the matrix of each encoder layer into queries, keys and values -def read_in_q_k_v(state_dict, config, base_model=False): - for i in range(config.num_hidden_layers): - if base_model: - prefix = "" - else: - prefix = "ijepa." - # read in weights + bias of input projection layer (in timm, this is a single matrix + bias) - in_proj_weight = state_dict.pop(f"blocks.{i}.attn.qkv.weight") - in_proj_bias = state_dict.pop(f"blocks.{i}.attn.qkv.bias") - # next, add query, keys and values (in that order) to the state dict - state_dict[f"{prefix}encoder.layer.{i}.attention.attention.query.weight"] = in_proj_weight[ - : config.hidden_size, : - ] - state_dict[f"{prefix}encoder.layer.{i}.attention.attention.query.bias"] = in_proj_bias[: config.hidden_size] - state_dict[f"{prefix}encoder.layer.{i}.attention.attention.key.weight"] = in_proj_weight[ - config.hidden_size : config.hidden_size * 2, : - ] - state_dict[f"{prefix}encoder.layer.{i}.attention.attention.key.bias"] = in_proj_bias[ - config.hidden_size : config.hidden_size * 2 - ] - state_dict[f"{prefix}encoder.layer.{i}.attention.attention.value.weight"] = in_proj_weight[ - -config.hidden_size :, : - ] - state_dict[f"{prefix}encoder.layer.{i}.attention.attention.value.bias"] = in_proj_bias[-config.hidden_size :] - - -def remove_classification_head_(state_dict): - ignore_keys = ["head.weight", "head.bias"] - for k in ignore_keys: - state_dict.pop(k, None) - - -def rename_key(dct, old, new): - val = dct.pop(old) - dct[new] = val - - -# We will verify our results on an image of cute cats -def prepare_img(): - url = "http://images.cocodataset.org/val2017/000000039769.jpg" - im = Image.open(requests.get(url, stream=True).raw) - return im - - -@torch.no_grad() -def convert_ijepa_checkpoint(ijepa_name, pytorch_dump_folder_path): - """ - Copy/paste/tweak model's weights to our IJEPA structure. - """ - - # define default IJEPA configuration - config = IJepaConfig() - base_model = False - - # load original model from timm - timm_model = timm.create_model(ijepa_name, pretrained=True) - timm_model.eval() - - # detect unsupported IJEPA models in transformers - # fc_norm is present - if not isinstance(getattr(timm_model, "fc_norm", None), torch.nn.Identity): - raise ValueError(f"{ijepa_name} is not supported in transformers because of the presence of fc_norm.") - - # use of global average pooling in combination (or without) class token - if getattr(timm_model, "global_pool", None) == "avg": - raise ValueError(f"{ijepa_name} is not supported in transformers because of use of global average pooling.") - - # CLIP style ijepa with norm_pre layer present - if "clip" in ijepa_name and not isinstance(getattr(timm_model, "norm_pre", None), torch.nn.Identity): - raise ValueError( - f"{ijepa_name} is not supported in transformers because it's a CLIP style IJEPA with norm_pre layer." - ) - - # SigLIP style ijepa with attn_pool layer present - if "siglip" in ijepa_name and getattr(timm_model, "global_pool", None) == "map": - raise ValueError( - f"{ijepa_name} is not supported in transformers because it's a SigLIP style IJEPA with attn_pool." - ) - - # use of layer scale in IJEPA model blocks - if not isinstance(getattr(timm_model.blocks[0], "ls1", None), torch.nn.Identity) or not isinstance( - getattr(timm_model.blocks[0], "ls2", None), torch.nn.Identity - ): - raise ValueError(f"{ijepa_name} is not supported in transformers because it uses a layer scale in its blocks.") - - # Hybrid ResNet-IJEPAs - if not isinstance(timm_model.patch_embed, timm.layers.PatchEmbed): - raise ValueError(f"{ijepa_name} is not supported in transformers because it is a hybrid ResNet-IJEPA.") - - # get patch size and image size from the patch embedding submodule - config.patch_size = timm_model.patch_embed.patch_size[0] - config.image_size = timm_model.patch_embed.img_size[0] - - # retrieve architecture-specific parameters from the timm model - config.hidden_size = timm_model.embed_dim - config.intermediate_size = timm_model.blocks[0].mlp.fc1.out_features - config.num_hidden_layers = len(timm_model.blocks) - config.num_attention_heads = timm_model.blocks[0].attn.num_heads - - # check whether the model has a classification head or not - if timm_model.num_classes != 0: - config.num_labels = timm_model.num_classes - # infer ImageNet subset from timm model - imagenet_subset = infer_imagenet_subset(timm_model) - dataset_info = ImageNetInfo(imagenet_subset) - config.id2label = {i: dataset_info.index_to_label_name(i) for i in range(dataset_info.num_classes())} - config.label2id = {v: k for k, v in config.id2label.items()} - else: - print(f"{ijepa_name} is going to be converted as a feature extractor only.") - base_model = True - - # load state_dict of original model - state_dict = timm_model.state_dict() - - # remove and rename some keys in the state dict - if base_model: - remove_classification_head_(state_dict) - rename_keys = create_rename_keys(config, base_model) - for src, dest in rename_keys: - rename_key(state_dict, src, dest) - read_in_q_k_v(state_dict, config, base_model) - - # load HuggingFace model - if base_model: - model = IJepaModel(config, add_pooling_layer=False).eval() - else: - model = IJepaForImageClassification(config).eval() - model.load_state_dict(state_dict) - - # Check outputs on an image, prepared by IJepaImageProcessor/DeiTImageProcessor - if "deit" in ijepa_name: - image_processor = DeiTImageProcessor(size=config.image_size) - else: - image_processor = IJepaImageProcessor(size=config.image_size) - encoding = image_processor(images=prepare_img(), return_tensors="pt") - pixel_values = encoding["pixel_values"] - outputs = model(pixel_values) - - if base_model: - timm_pooled_output = timm_model.forward_features(pixel_values) - assert timm_pooled_output.shape == outputs.last_hidden_state.shape - assert torch.allclose(timm_pooled_output, outputs.last_hidden_state, atol=1e-1) - else: - timm_logits = timm_model(pixel_values) - assert timm_logits.shape == outputs.logits.shape - assert torch.allclose(timm_logits, outputs.logits, atol=1e-3) - - Path(pytorch_dump_folder_path).mkdir(exist_ok=True) - print(f"Saving model {ijepa_name} to {pytorch_dump_folder_path}") - model.save_pretrained(pytorch_dump_folder_path) - print(f"Saving image processor to {pytorch_dump_folder_path}") - image_processor.save_pretrained(pytorch_dump_folder_path) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - # Required parameters - parser.add_argument( - "--ijepa_name", - default="ijepa_base_patch16_224", - type=str, - help="Name of the IJEPA timm model you'd like to convert.", - ) - parser.add_argument( - "--pytorch_dump_folder_path", - default=None, - type=str, - help="Path to the output PyTorch model directory.", - ) - - args = parser.parse_args() - convert_ijepa_checkpoint(args.ijepa_name, args.pytorch_dump_folder_path) diff --git a/src/transformers/models/ijepa/modeling_ijepa.py b/src/transformers/models/ijepa/modeling_ijepa.py index 1395bb0d5f68..cc64f25c38a2 100644 --- a/src/transformers/models/ijepa/modeling_ijepa.py +++ b/src/transformers/models/ijepa/modeling_ijepa.py @@ -510,14 +510,22 @@ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> No config ([`IJepaConfig`]): Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. + + add_pooling_layer (bool, *optional*, defaults to `True`): Whether to include a pooling layer in the model. If set to `True`, the model will include a pooling layer that can be used to extract a pooled output representation of the hidden states. If set to `False`, the pooling layer will be omitted. + + use_mask_token (bool, *optional*, defaults to `False`): Whether to use a mask token in the embeddings layer. If set to `True`, a special token will be used to mask certain inputs during training. If set to `False`, no mask token will be used, and the embeddings layer will function without masking capability. + """ IJEPA_INPUTS_DOCSTRING = r""" Args: pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): - Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`IJepaImageProcessor.__call__`] + Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`ViTImageProcessor.__call__`] for details. + bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*): + Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). + head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: @@ -581,10 +589,6 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple, BaseModelOutputWithPooling]: - r""" - bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*): - Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). - """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states @@ -650,13 +654,25 @@ def forward(self, hidden_states): return pooled_output +IJEPA_FOR_IMAGE_CLASSIFICATION_START_DOCSTRING = r""" + This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it + as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and + behavior. + + Parameters: + config ([`IJepaConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + @add_start_docstrings( """ I-JEPA Model transformer with an image classification head on top (a linear layer on top of the final hidden state of the average pooling of the output) e.g. for ImageNet. """, - IJEPA_START_DOCSTRING, + IJEPA_FOR_IMAGE_CLASSIFICATION_START_DOCSTRING, ) class IJepaForImageClassification(IJepaPreTrainedModel): def __init__(self, config: IJepaConfig) -> None: diff --git a/src/transformers/utils/dummy_flax_objects.py b/src/transformers/utils/dummy_flax_objects.py index cb8f0e8ff047..0f2390fb694b 100644 --- a/src/transformers/utils/dummy_flax_objects.py +++ b/src/transformers/utils/dummy_flax_objects.py @@ -842,27 +842,6 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["flax"]) -class FlaxIJepaForImageClassification(metaclass=DummyObject): - _backends = ["flax"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["flax"]) - - -class FlaxIJepaModel(metaclass=DummyObject): - _backends = ["flax"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["flax"]) - - -class FlaxIJepaPreTrainedModel(metaclass=DummyObject): - _backends = ["flax"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["flax"]) - - class FlaxLlamaForCausalLM(metaclass=DummyObject): _backends = ["flax"] diff --git a/src/transformers/utils/dummy_tf_objects.py b/src/transformers/utils/dummy_tf_objects.py index 7663a6d652d5..942a7afced4b 100644 --- a/src/transformers/utils/dummy_tf_objects.py +++ b/src/transformers/utils/dummy_tf_objects.py @@ -1563,27 +1563,6 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["tf"]) -class TFIJepaForImageClassification(metaclass=DummyObject): - _backends = ["tf"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["tf"]) - - -class TFIJepaModel(metaclass=DummyObject): - _backends = ["tf"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["tf"]) - - -class TFIJepaPreTrainedModel(metaclass=DummyObject): - _backends = ["tf"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["tf"]) - - class TFLayoutLMForMaskedLM(metaclass=DummyObject): _backends = ["tf"] diff --git a/src/transformers/utils/dummy_torchvision_objects.py b/src/transformers/utils/dummy_torchvision_objects.py index 38fbbf313c2b..1d532aeea2a4 100644 --- a/src/transformers/utils/dummy_torchvision_objects.py +++ b/src/transformers/utils/dummy_torchvision_objects.py @@ -9,13 +9,6 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torchvision"]) -class IJepaImageProcessorFast(metaclass=DummyObject): - _backends = ["torchvision"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torchvision"]) - - class ViTImageProcessorFast(metaclass=DummyObject): _backends = ["torchvision"] diff --git a/src/transformers/utils/dummy_vision_objects.py b/src/transformers/utils/dummy_vision_objects.py index ffbae4164329..19f8dc1b1d9c 100644 --- a/src/transformers/utils/dummy_vision_objects.py +++ b/src/transformers/utils/dummy_vision_objects.py @@ -296,13 +296,6 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["vision"]) -class IJepaImageProcessor(metaclass=DummyObject): - _backends = ["vision"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["vision"]) - - class ImageGPTFeatureExtractor(metaclass=DummyObject): _backends = ["vision"] From 66773eee3a0a144a28691e2f8b2e3c65764b39db Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Marcelo=20Tozato?= Date: Sun, 25 Aug 2024 19:44:27 +0200 Subject: [PATCH 11/44] rename conversion script --- .../ijepa/{convert_dino_to_pytorch.py => convert_ijepa_to_hf.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename src/transformers/models/ijepa/{convert_dino_to_pytorch.py => convert_ijepa_to_hf.py} (100%) diff --git a/src/transformers/models/ijepa/convert_dino_to_pytorch.py b/src/transformers/models/ijepa/convert_ijepa_to_hf.py similarity index 100% rename from src/transformers/models/ijepa/convert_dino_to_pytorch.py rename to src/transformers/models/ijepa/convert_ijepa_to_hf.py From 9b7e8b448a6391ba144ba9a1ec60518988e71d0d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Marcelo=20Tozato?= Date: Sun, 25 Aug 2024 19:50:29 +0200 Subject: [PATCH 12/44] Add I-JEPA to sdpa docs --- docs/source/en/perf_infer_gpu_one.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/source/en/perf_infer_gpu_one.md b/docs/source/en/perf_infer_gpu_one.md index df1e64e36877..2376b2181a5a 100644 --- a/docs/source/en/perf_infer_gpu_one.md +++ b/docs/source/en/perf_infer_gpu_one.md @@ -214,6 +214,7 @@ For now, Transformers supports SDPA inference and training for the following arc * [GPT2](https://huggingface.co/docs/transformers/model_doc/gpt2) * [GPTBigCode](https://huggingface.co/docs/transformers/model_doc/gpt_bigcode#transformers.GPTBigCodeModel) * [GPTNeoX](https://huggingface.co/docs/transformers/model_doc/gpt_neox#transformers.GPTNeoXModel) +* [I-JEPA](https://huggingface.co/docs/transformers/model_doc/ijepa#transformers.IJepaModel) * [JetMoe](https://huggingface.co/docs/transformers/model_doc/jetmoe#transformers.JetMoeModel) * [Jamba](https://huggingface.co/docs/transformers/model_doc/jamba#transformers.JambaModel) * [Llama](https://huggingface.co/docs/transformers/model_doc/llama#transformers.LlamaModel) From 40cf5284175c1688f4fe5a0d5df2a9acff1e4bbf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Marcelo=20Tozato?= Date: Wed, 28 Aug 2024 10:22:03 +0200 Subject: [PATCH 13/44] minor fixes --- docs/source/en/model_doc/ijepa.md | 17 +++----- src/transformers/models/ijepa/__init__.py | 7 +-- .../models/ijepa/configuration_ijepa.py | 29 ------------- .../models/ijepa/modeling_ijepa.py | 43 +++++++++++++++---- src/transformers/utils/fx.py | 1 + tests/models/ijepa/test_modeling_ijepa.py | 2 +- 6 files changed, 42 insertions(+), 57 deletions(-) diff --git a/docs/source/en/model_doc/ijepa.md b/docs/source/en/model_doc/ijepa.md index 2ce748fb39f9..c966fe6b02aa 100644 --- a/docs/source/en/model_doc/ijepa.md +++ b/docs/source/en/model_doc/ijepa.md @@ -18,19 +18,15 @@ rendered properly in your Markdown viewer. ## Overview -The I-JEPA model was proposed in []() by . - +The I-JEPA model was proposed in [Image-based Joint-Embedding Predictive Architecture](https://arxiv.org/pdf/2301.08243.pdf) by Mahmoud Assran, Quentin Duval, Ishan Misra, Piotr Bojanowski, Pascal Vincent, Michael Rabbat, Yann LeCun, Nicolas Ballas. +I-JEPA is a self-supervised learning method that predicts the representations of one part of an image based on other parts of the same image. This approach focuses on learning semantic features without relying on pre-defined invariances from hand-crafted data transformations, which can bias specific tasks, or on filling in pixel-level details, which often leads to less meaningful representations. The abstract from the paper is the following: -** +This paper demonstrates an approach for learning highly semantic image representations without relying on hand-crafted data-augmentations. We introduce the Image- based Joint-Embedding Predictive Architecture (I-JEPA), a non-generative approach for self-supervised learning from images. The idea behind I-JEPA is simple: from a single context block, predict the representations of various target blocks in the same image. A core design choice to guide I-JEPA towards producing semantic representations is the masking strategy; specifically, it is crucial to (a) sample tar- get blocks with sufficiently large scale (semantic), and to (b) use a sufficiently informative (spatially distributed) context block. Empirically, when combined with Vision Transform- ers, we find I-JEPA to be highly scalable. For instance, we train a ViT-Huge/14 on ImageNet using 16 A100 GPUs in under 72 hours to achieve strong downstream performance across a wide range of tasks, from linear classification to object counting and depth prediction. -Tips: - - - -This model was contributed by [INSERT YOUR HF USERNAME HERE](https://huggingface.co/). -The original code can be found [here](). +This model was contributed by [jmtzt](https://huggingface.co/jmtzt). +The original code can be found [here](https://github.com/facebookresearch/ijepa). ## IJepaConfig @@ -46,6 +42,3 @@ The original code can be found [here](). [[autodoc]] IJepaForImageClassification - forward - - - diff --git a/src/transformers/models/ijepa/__init__.py b/src/transformers/models/ijepa/__init__.py index 0fdf4806c789..50ab72784d98 100644 --- a/src/transformers/models/ijepa/__init__.py +++ b/src/transformers/models/ijepa/__init__.py @@ -52,9 +52,4 @@ else: import sys - sys.modules[__name__] = _LazyModule( - __name__, - globals()["__file__"], - _import_structure, - module_spec=__spec__, - ) + sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) diff --git a/src/transformers/models/ijepa/configuration_ijepa.py b/src/transformers/models/ijepa/configuration_ijepa.py index a1d551db4bbd..26378e6e81d9 100644 --- a/src/transformers/models/ijepa/configuration_ijepa.py +++ b/src/transformers/models/ijepa/configuration_ijepa.py @@ -14,17 +14,7 @@ # limitations under the License. """I-JEPA model configuration""" -from collections import OrderedDict -from typing import Mapping - -from packaging import version - from ...configuration_utils import PretrainedConfig -from ...onnx import OnnxConfig -from ...utils import logging - - -logger = logging.get_logger(__name__) class IJepaConfig(PretrainedConfig): @@ -116,22 +106,3 @@ def __init__( self.patch_size = patch_size self.num_channels = num_channels self.qkv_bias = qkv_bias - - -class IJepaOnnxConfig(OnnxConfig): - torch_onnx_minimum_version = version.parse("1.11") - - @property - def inputs(self) -> Mapping[str, Mapping[int, str]]: - return OrderedDict( - [ - ( - "pixel_values", - {0: "batch", 1: "num_channels", 2: "height", 3: "width"}, - ), - ] - ) - - @property - def atol_for_validation(self) -> float: - return 1e-4 diff --git a/src/transformers/models/ijepa/modeling_ijepa.py b/src/transformers/models/ijepa/modeling_ijepa.py index cc64f25c38a2..1677564b97ca 100644 --- a/src/transformers/models/ijepa/modeling_ijepa.py +++ b/src/transformers/models/ijepa/modeling_ijepa.py @@ -1,5 +1,5 @@ # coding=utf-8 -# Copyright 2024 Google AI, Ross Wightman, The HuggingFace Inc. team. All rights reserved. +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -53,7 +53,7 @@ _EXPECTED_OUTPUT_SHAPE = [1, 197, 768] # Image classification docstring -_IMAGE_CLASS_CHECKPOINT = "google/ijepa-base-patch16-224" +_IMAGE_CLASS_CHECKPOINT = "jmtzt/ijepa_huge_patch14_1k" _IMAGE_CLASS_EXPECTED_OUTPUT = "Egyptian cat" @@ -535,9 +535,11 @@ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> No output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. + output_hidden_states (`bool`, *optional*): Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more detail. + return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. """ @@ -665,6 +667,35 @@ def forward(self, hidden_states): configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. """ +IJEPA_FOR_IMAGE_CLASSIFICATION_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`ViTImageProcessor.__call__`] + for details. + + head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the image classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + @add_start_docstrings( """ @@ -687,7 +718,7 @@ def __init__(self, config: IJepaConfig) -> None: # Initialize weights and apply final processing self.post_init() - @add_start_docstrings_to_model_forward(IJEPA_INPUTS_DOCSTRING) + @add_start_docstrings_to_model_forward(IJEPA_FOR_IMAGE_CLASSIFICATION_INPUTS_DOCSTRING) @add_code_sample_docstrings( checkpoint=_IMAGE_CLASS_CHECKPOINT, output_type=ImageClassifierOutput, @@ -703,12 +734,6 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[tuple, ImageClassifierOutput]: - r""" - labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels for computing the image classification/regression loss. Indices should be in `[0, ..., - config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If - `config.num_labels > 1` a classification loss is computed (Cross-Entropy). - """ return_dict = return_dict if return_dict is not None else self.config.use_return_dict outputs = self.ijepa( diff --git a/src/transformers/utils/fx.py b/src/transformers/utils/fx.py index c78b4c34c331..c5e98b8edef6 100755 --- a/src/transformers/utils/fx.py +++ b/src/transformers/utils/fx.py @@ -140,6 +140,7 @@ def _generate_supported_model_class_names( "gptj", "hiera", "hubert", + "ijepa", "layoutlm", "llama", "cohere", diff --git a/tests/models/ijepa/test_modeling_ijepa.py b/tests/models/ijepa/test_modeling_ijepa.py index 371db004c62e..bde91cff7b6e 100644 --- a/tests/models/ijepa/test_modeling_ijepa.py +++ b/tests/models/ijepa/test_modeling_ijepa.py @@ -202,7 +202,7 @@ class IJepaModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): if is_torch_available() else {} ) - fx_compatible = False + fx_compatible = True test_pruning = False test_resize_embeddings = False From 2bae64ace91f7f3b1ba17fa62b6637cbd52f0d7d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Marcelo=20Tozato?= Date: Wed, 28 Aug 2024 10:22:40 +0200 Subject: [PATCH 14/44] adjust conversion script --- .../models/ijepa/convert_ijepa_to_hf.py | 41 +++++++++++-------- 1 file changed, 23 insertions(+), 18 deletions(-) diff --git a/src/transformers/models/ijepa/convert_ijepa_to_hf.py b/src/transformers/models/ijepa/convert_ijepa_to_hf.py index aeed304224c1..de60883fe7e4 100644 --- a/src/transformers/models/ijepa/convert_ijepa_to_hf.py +++ b/src/transformers/models/ijepa/convert_ijepa_to_hf.py @@ -169,6 +169,8 @@ def get_ijepa_config(model_name): config.layer_norm_eps = 1e-6 config.mlp_ratio = 4 config.intermediate_size = 5120 + if model_name == "ijepa_vith16_1k": + config.image_size = 448 elif "vitg" in model_name: config.hidden_size = 1408 config.num_hidden_layers = 40 @@ -182,7 +184,7 @@ def get_ijepa_config(model_name): @torch.no_grad() -def convert_ijepa_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub): +def convert_ijepa_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub, verify_logits): """ Copy/paste/tweak model's weights to our IJEPA structure. """ @@ -214,28 +216,28 @@ def convert_ijepa_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub): model = IJepaModel(config, add_pooling_layer=False).eval() model.load_state_dict(state_dict) - # Check outputs on an image, prepared by ViTImageProcessor - image_processor = ViTImageProcessor() - encoding = image_processor(images=prepare_img(), return_tensors="pt") - pixel_values = encoding["pixel_values"] - outputs = model(pixel_values) + if verify_logits: + # Check outputs on an image, prepared by ViTImageProcessor + image_processor = ViTImageProcessor() + encoding = image_processor(images=prepare_img(), return_tensors="pt") + pixel_values = encoding["pixel_values"] + with torch.no_grad(): + outputs = model(pixel_values) - expected_slice = torch.Tensor( - [[-0.0621, -0.0054, -2.7513], [-0.1952, 0.0909, -3.9536], [0.0942, -0.0331, -1.2833]] - ) + expected_slice = torch.Tensor( + [[-0.0621, -0.0054, -2.7513], [-0.1952, 0.0909, -3.9536], [0.0942, -0.0331, -1.2833]] + ) - assert torch.allclose( - expected_slice, - outputs.last_hidden_state[0, :3, :3], - atol=1e-4, - ) + assert torch.allclose( + expected_slice, + outputs.last_hidden_state[0, :3, :3], + atol=1e-4, + ) if pytorch_dump_folder_path is not None: Path(pytorch_dump_folder_path).mkdir(exist_ok=True) print(f"Saving model {model_name} to {pytorch_dump_folder_path}") model.save_pretrained(pytorch_dump_folder_path) - print(f"Saving image processor to {pytorch_dump_folder_path}") - image_processor.save_pretrained(pytorch_dump_folder_path) if push_to_hub: model_name_to_hf_name = { @@ -272,9 +274,12 @@ def convert_ijepa_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub): parser.add_argument( "--push_to_hub", action="store_true", - help="Whether or not to push the model to the Hub.", + help="Whether or not to push the model to the 🤗 Hub.", + ) + parser.add_argument( + "--verify_logits", action="store_false", help="Whether or not to verify logits after conversion." ) parser.set_defaults() args = parser.parse_args() - convert_ijepa_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.push_to_hub) + convert_ijepa_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.push_to_hub, args.verify_logits) \ No newline at end of file From 4ccf28cdb7d8e3873e1b62d40d02363a9e7ccf16 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Marcelo=20Tozato?= Date: Wed, 28 Aug 2024 10:25:55 +0200 Subject: [PATCH 15/44] update conversion script --- src/transformers/models/ijepa/convert_ijepa_to_hf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/ijepa/convert_ijepa_to_hf.py b/src/transformers/models/ijepa/convert_ijepa_to_hf.py index de60883fe7e4..0546aae478e1 100644 --- a/src/transformers/models/ijepa/convert_ijepa_to_hf.py +++ b/src/transformers/models/ijepa/convert_ijepa_to_hf.py @@ -282,4 +282,4 @@ def convert_ijepa_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub, parser.set_defaults() args = parser.parse_args() - convert_ijepa_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.push_to_hub, args.verify_logits) \ No newline at end of file + convert_ijepa_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.push_to_hub, args.verify_logits) From 851ed7e733f02bed826ba53f0d9062668ad74fe5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Marcelo=20Tozato?= Date: Wed, 28 Aug 2024 10:29:07 +0200 Subject: [PATCH 16/44] adjust sdpa docs --- docs/source/en/perf_infer_gpu_one.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/docs/source/en/perf_infer_gpu_one.md b/docs/source/en/perf_infer_gpu_one.md index 8dd970760f85..b00da75b72c8 100644 --- a/docs/source/en/perf_infer_gpu_one.md +++ b/docs/source/en/perf_infer_gpu_one.md @@ -212,9 +212,12 @@ For now, Transformers supports SDPA inference and training for the following arc * [Falcon](https://huggingface.co/docs/transformers/model_doc/falcon#transformers.FalconModel) * [Gemma](https://huggingface.co/docs/transformers/model_doc/gemma#transformers.GemmaModel) * [Gemma2](https://huggingface.co/docs/transformers/model_doc/gemma2#transformers.Gemma2Model) +* [Granite](https://huggingface.co/docs/transformers/model_doc/granite#transformers.GraniteModel) * [GPT2](https://huggingface.co/docs/transformers/model_doc/gpt2) * [GPTBigCode](https://huggingface.co/docs/transformers/model_doc/gpt_bigcode#transformers.GPTBigCodeModel) * [GPTNeoX](https://huggingface.co/docs/transformers/model_doc/gpt_neox#transformers.GPTNeoXModel) +* [Hubert](https://huggingface.co/docs/transformers/model_doc/hubert#transformers.HubertModel) +* [Idefics](https://huggingface.co/docs/transformers/model_doc/idefics#transformers.IdeficsModel) * [I-JEPA](https://huggingface.co/docs/transformers/model_doc/ijepa#transformers.IJepaModel) * [JetMoe](https://huggingface.co/docs/transformers/model_doc/jetmoe#transformers.JetMoeModel) * [Jamba](https://huggingface.co/docs/transformers/model_doc/jamba#transformers.JambaModel) From b7a027cc508309250bc8ccbc04d656c05484288e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Marcelo=20Tozato?= Date: Wed, 28 Aug 2024 10:38:57 +0200 Subject: [PATCH 17/44] [run_slow] ijepa From 552e800bcee0c57810cf95bfa9dc5178cf5c8139 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Marcelo=20Tozato?= Date: Fri, 30 Aug 2024 20:28:09 +0200 Subject: [PATCH 18/44] [run-slow] ijepa --- src/transformers/models/ijepa/modeling_ijepa.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/transformers/models/ijepa/modeling_ijepa.py b/src/transformers/models/ijepa/modeling_ijepa.py index 1677564b97ca..0a93255c34ef 100644 --- a/src/transformers/models/ijepa/modeling_ijepa.py +++ b/src/transformers/models/ijepa/modeling_ijepa.py @@ -109,7 +109,6 @@ def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: mode="bicubic", align_corners=False, ) - assert int(h0) == patch_pos_embed.shape[-2] and int(w0) == patch_pos_embed.shape[-1] patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) From f2f7eb8dd10e94bbfbc33c28794c073806c1345b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Marcelo=20Tozato?= Date: Sat, 31 Aug 2024 10:20:54 +0200 Subject: [PATCH 19/44] [run-slow] ijepa From f24ef1262ac1987284b2e79b1ced962ac7172537 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Marcelo=20Tozato?= Date: Mon, 2 Sep 2024 19:48:19 +0200 Subject: [PATCH 20/44] [run-slow] ijepa From 6f9acc9250af1e05324414babc91abb2c7be52b5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Marcelo=20Tozato?= Date: Mon, 2 Sep 2024 20:36:20 +0200 Subject: [PATCH 21/44] [run-slow] ijepa From d663ea3b5adeb03f8934526af1b36a7a1da33d76 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Marcelo=20Tozato?= Date: Mon, 2 Sep 2024 20:44:58 +0200 Subject: [PATCH 22/44] [run-slow] ijepa From 7da705b36557a5b77b43439307e46929fd012888 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Marcelo=20Tozato?= Date: Sat, 16 Nov 2024 19:46:48 +0100 Subject: [PATCH 23/44] formatting issues --- src/transformers/__init__.py | 311 +++++------------- .../models/auto/configuration_auto.py | 77 ++--- .../models/auto/image_processing_auto.py | 137 ++------ src/transformers/models/auto/modeling_auto.py | 149 +++------ src/transformers/utils/dummy_pt_objects.py | 4 +- 5 files changed, 170 insertions(+), 508 deletions(-) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index d5be2cfe79dc..e62093dd7cfb 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -48,6 +48,7 @@ logging, ) + logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -371,9 +372,7 @@ "Speech2Text2Tokenizer", ], "models.deprecated.tapex": ["TapexTokenizer"], - "models.deprecated.trajectory_transformer": [ - "TrajectoryTransformerConfig" - ], + "models.deprecated.trajectory_transformer": ["TrajectoryTransformerConfig"], "models.deprecated.transfo_xl": [ "TransfoXLConfig", "TransfoXLCorpus", @@ -1000,26 +999,20 @@ from .utils import dummy_sentencepiece_objects _import_structure["utils.dummy_sentencepiece_objects"] = [ - name - for name in dir(dummy_sentencepiece_objects) - if not name.startswith("_") + name for name in dir(dummy_sentencepiece_objects) if not name.startswith("_") ] else: _import_structure["models.albert"].append("AlbertTokenizer") _import_structure["models.barthez"].append("BarthezTokenizer") _import_structure["models.bartpho"].append("BartphoTokenizer") - _import_structure["models.bert_generation"].append( - "BertGenerationTokenizer" - ) + _import_structure["models.bert_generation"].append("BertGenerationTokenizer") _import_structure["models.big_bird"].append("BigBirdTokenizer") _import_structure["models.camembert"].append("CamembertTokenizer") _import_structure["models.code_llama"].append("CodeLlamaTokenizer") _import_structure["models.cpm"].append("CpmTokenizer") _import_structure["models.deberta_v2"].append("DebertaV2Tokenizer") _import_structure["models.deprecated.ernie_m"].append("ErnieMTokenizer") - _import_structure["models.deprecated.xlm_prophetnet"].append( - "XLMProphetNetTokenizer" - ) + _import_structure["models.deprecated.xlm_prophetnet"].append("XLMProphetNetTokenizer") _import_structure["models.fnet"].append("FNetTokenizer") _import_structure["models.gemma"].append("GemmaTokenizer") _import_structure["models.gpt_sw3"].append("GPTSw3Tokenizer") @@ -1054,9 +1047,7 @@ from .utils import dummy_tokenizers_objects _import_structure["utils.dummy_tokenizers_objects"] = [ - name - for name in dir(dummy_tokenizers_objects) - if not name.startswith("_") + name for name in dir(dummy_tokenizers_objects) if not name.startswith("_") ] else: # Fast tokenizers structure @@ -1066,9 +1057,7 @@ _import_structure["models.bert"].append("BertTokenizerFast") _import_structure["models.big_bird"].append("BigBirdTokenizerFast") _import_structure["models.blenderbot"].append("BlenderbotTokenizerFast") - _import_structure["models.blenderbot_small"].append( - "BlenderbotSmallTokenizerFast" - ) + _import_structure["models.blenderbot_small"].append("BlenderbotSmallTokenizerFast") _import_structure["models.bloom"].append("BloomTokenizerFast") _import_structure["models.camembert"].append("CamembertTokenizerFast") _import_structure["models.clip"].append("CLIPTokenizerFast") @@ -1080,9 +1069,7 @@ _import_structure["models.deberta"].append("DebertaTokenizerFast") _import_structure["models.deberta_v2"].append("DebertaV2TokenizerFast") _import_structure["models.deprecated.realm"].append("RealmTokenizerFast") - _import_structure["models.deprecated.retribert"].append( - "RetriBertTokenizerFast" - ) + _import_structure["models.deprecated.retribert"].append("RetriBertTokenizerFast") _import_structure["models.distilbert"].append("DistilBertTokenizerFast") _import_structure["models.dpr"].extend( [ @@ -1097,9 +1084,7 @@ _import_structure["models.gemma"].append("GemmaTokenizerFast") _import_structure["models.gpt2"].append("GPT2TokenizerFast") _import_structure["models.gpt_neox"].append("GPTNeoXTokenizerFast") - _import_structure["models.gpt_neox_japanese"].append( - "GPTNeoXJapaneseTokenizer" - ) + _import_structure["models.gpt_neox_japanese"].append("GPTNeoXJapaneseTokenizer") _import_structure["models.herbert"].append("HerbertTokenizerFast") _import_structure["models.layoutlm"].append("LayoutLMTokenizerFast") _import_structure["models.layoutlmv2"].append("LayoutLMv2TokenizerFast") @@ -1144,9 +1129,7 @@ from .utils import dummy_sentencepiece_and_tokenizers_objects _import_structure["utils.dummy_sentencepiece_and_tokenizers_objects"] = [ - name - for name in dir(dummy_sentencepiece_and_tokenizers_objects) - if not name.startswith("_") + name for name in dir(dummy_sentencepiece_and_tokenizers_objects) if not name.startswith("_") ] else: _import_structure["convert_slow_tokenizer"] = [ @@ -1162,9 +1145,7 @@ from .utils import dummy_tensorflow_text_objects _import_structure["utils.dummy_tensorflow_text_objects"] = [ - name - for name in dir(dummy_tensorflow_text_objects) - if not name.startswith("_") + name for name in dir(dummy_tensorflow_text_objects) if not name.startswith("_") ] else: _import_structure["models.bert"].append("TFBertTokenizer") @@ -1177,9 +1158,7 @@ from .utils import dummy_keras_nlp_objects _import_structure["utils.dummy_keras_nlp_objects"] = [ - name - for name in dir(dummy_keras_nlp_objects) - if not name.startswith("_") + name for name in dir(dummy_keras_nlp_objects) if not name.startswith("_") ] else: _import_structure["models.gpt2"].append("TFGPT2Tokenizer") @@ -1198,39 +1177,25 @@ _import_structure["image_processing_base"] = ["ImageProcessingMixin"] _import_structure["image_processing_utils"] = ["BaseImageProcessor"] _import_structure["image_utils"] = ["ImageFeatureExtractionMixin"] - _import_structure["models.beit"].extend( - ["BeitFeatureExtractor", "BeitImageProcessor"] - ) + _import_structure["models.beit"].extend(["BeitFeatureExtractor", "BeitImageProcessor"]) _import_structure["models.bit"].extend(["BitImageProcessor"]) _import_structure["models.blip"].extend(["BlipImageProcessor"]) _import_structure["models.bridgetower"].append("BridgeTowerImageProcessor") _import_structure["models.chameleon"].append("ChameleonImageProcessor") - _import_structure["models.chinese_clip"].extend( - ["ChineseCLIPFeatureExtractor", "ChineseCLIPImageProcessor"] - ) - _import_structure["models.clip"].extend( - ["CLIPFeatureExtractor", "CLIPImageProcessor"] - ) + _import_structure["models.chinese_clip"].extend(["ChineseCLIPFeatureExtractor", "ChineseCLIPImageProcessor"]) + _import_structure["models.clip"].extend(["CLIPFeatureExtractor", "CLIPImageProcessor"]) _import_structure["models.conditional_detr"].extend( ["ConditionalDetrFeatureExtractor", "ConditionalDetrImageProcessor"] ) - _import_structure["models.convnext"].extend( - ["ConvNextFeatureExtractor", "ConvNextImageProcessor"] - ) + _import_structure["models.convnext"].extend(["ConvNextFeatureExtractor", "ConvNextImageProcessor"]) _import_structure["models.deformable_detr"].extend( ["DeformableDetrFeatureExtractor", "DeformableDetrImageProcessor"] ) - _import_structure["models.deit"].extend( - ["DeiTFeatureExtractor", "DeiTImageProcessor"] - ) + _import_structure["models.deit"].extend(["DeiTFeatureExtractor", "DeiTImageProcessor"]) _import_structure["models.deprecated.deta"].append("DetaImageProcessor") - _import_structure["models.deprecated.efficientformer"].append( - "EfficientFormerImageProcessor" - ) + _import_structure["models.deprecated.efficientformer"].append("EfficientFormerImageProcessor") _import_structure["models.deprecated.tvlt"].append("TvltImageProcessor") - _import_structure["models.deprecated.vit_hybrid"].extend( - ["ViTHybridImageProcessor"] - ) + _import_structure["models.deprecated.vit_hybrid"].extend(["ViTHybridImageProcessor"]) _import_structure["models.detr"].extend( [ "DetrFeatureExtractor", @@ -1238,109 +1203,57 @@ "DetrImageProcessorFast", ] ) - _import_structure["models.donut"].extend( - ["DonutFeatureExtractor", "DonutImageProcessor"] - ) - _import_structure["models.dpt"].extend( - ["DPTFeatureExtractor", "DPTImageProcessor"] - ) - _import_structure["models.efficientnet"].append( - "EfficientNetImageProcessor" - ) - _import_structure["models.flava"].extend( - ["FlavaFeatureExtractor", "FlavaImageProcessor", "FlavaProcessor"] - ) - _import_structure["models.fuyu"].extend( - ["FuyuImageProcessor", "FuyuProcessor"] - ) - _import_structure["models.glpn"].extend( - ["GLPNFeatureExtractor", "GLPNImageProcessor"] - ) - _import_structure["models.grounding_dino"].extend( - ["GroundingDinoImageProcessor"] - ) + _import_structure["models.donut"].extend(["DonutFeatureExtractor", "DonutImageProcessor"]) + _import_structure["models.dpt"].extend(["DPTFeatureExtractor", "DPTImageProcessor"]) + _import_structure["models.efficientnet"].append("EfficientNetImageProcessor") + _import_structure["models.flava"].extend(["FlavaFeatureExtractor", "FlavaImageProcessor", "FlavaProcessor"]) + _import_structure["models.fuyu"].extend(["FuyuImageProcessor", "FuyuProcessor"]) + _import_structure["models.glpn"].extend(["GLPNFeatureExtractor", "GLPNImageProcessor"]) + _import_structure["models.grounding_dino"].extend(["GroundingDinoImageProcessor"]) _import_structure["models.idefics"].extend(["IdeficsImageProcessor"]) _import_structure["models.idefics2"].extend(["Idefics2ImageProcessor"]) _import_structure["models.idefics3"].extend(["Idefics3ImageProcessor"]) - _import_structure["models.imagegpt"].extend( - ["ImageGPTFeatureExtractor", "ImageGPTImageProcessor"] - ) - _import_structure["models.instructblipvideo"].extend( - ["InstructBlipVideoImageProcessor"] - ) - _import_structure["models.layoutlmv2"].extend( - ["LayoutLMv2FeatureExtractor", "LayoutLMv2ImageProcessor"] - ) - _import_structure["models.layoutlmv3"].extend( - ["LayoutLMv3FeatureExtractor", "LayoutLMv3ImageProcessor"] - ) - _import_structure["models.levit"].extend( - ["LevitFeatureExtractor", "LevitImageProcessor"] - ) + _import_structure["models.imagegpt"].extend(["ImageGPTFeatureExtractor", "ImageGPTImageProcessor"]) + _import_structure["models.instructblipvideo"].extend(["InstructBlipVideoImageProcessor"]) + _import_structure["models.layoutlmv2"].extend(["LayoutLMv2FeatureExtractor", "LayoutLMv2ImageProcessor"]) + _import_structure["models.layoutlmv3"].extend(["LayoutLMv3FeatureExtractor", "LayoutLMv3ImageProcessor"]) + _import_structure["models.levit"].extend(["LevitFeatureExtractor", "LevitImageProcessor"]) _import_structure["models.llava_next"].append("LlavaNextImageProcessor") - _import_structure["models.llava_next_video"].append( - "LlavaNextVideoImageProcessor" - ) + _import_structure["models.llava_next_video"].append("LlavaNextVideoImageProcessor") _import_structure["models.llava_onevision"].extend( ["LlavaOnevisionImageProcessor", "LlavaOnevisionVideoProcessor"] ) _import_structure["models.mask2former"].append("Mask2FormerImageProcessor") - _import_structure["models.maskformer"].extend( - ["MaskFormerFeatureExtractor", "MaskFormerImageProcessor"] - ) + _import_structure["models.maskformer"].extend(["MaskFormerFeatureExtractor", "MaskFormerImageProcessor"]) _import_structure["models.mllama"].extend(["MllamaImageProcessor"]) - _import_structure["models.mobilenet_v1"].extend( - ["MobileNetV1FeatureExtractor", "MobileNetV1ImageProcessor"] - ) - _import_structure["models.mobilenet_v2"].extend( - ["MobileNetV2FeatureExtractor", "MobileNetV2ImageProcessor"] - ) - _import_structure["models.mobilevit"].extend( - ["MobileViTFeatureExtractor", "MobileViTImageProcessor"] - ) + _import_structure["models.mobilenet_v1"].extend(["MobileNetV1FeatureExtractor", "MobileNetV1ImageProcessor"]) + _import_structure["models.mobilenet_v2"].extend(["MobileNetV2FeatureExtractor", "MobileNetV2ImageProcessor"]) + _import_structure["models.mobilevit"].extend(["MobileViTFeatureExtractor", "MobileViTImageProcessor"]) _import_structure["models.nougat"].append("NougatImageProcessor") _import_structure["models.oneformer"].extend(["OneFormerImageProcessor"]) _import_structure["models.owlv2"].append("Owlv2ImageProcessor") - _import_structure["models.owlvit"].extend( - ["OwlViTFeatureExtractor", "OwlViTImageProcessor"] - ) - _import_structure["models.perceiver"].extend( - ["PerceiverFeatureExtractor", "PerceiverImageProcessor"] - ) + _import_structure["models.owlvit"].extend(["OwlViTFeatureExtractor", "OwlViTImageProcessor"]) + _import_structure["models.perceiver"].extend(["PerceiverFeatureExtractor", "PerceiverImageProcessor"]) _import_structure["models.pix2struct"].extend(["Pix2StructImageProcessor"]) _import_structure["models.pixtral"].append("PixtralImageProcessor") - _import_structure["models.poolformer"].extend( - ["PoolFormerFeatureExtractor", "PoolFormerImageProcessor"] - ) + _import_structure["models.poolformer"].extend(["PoolFormerFeatureExtractor", "PoolFormerImageProcessor"]) _import_structure["models.pvt"].extend(["PvtImageProcessor"]) _import_structure["models.qwen2_vl"].extend(["Qwen2VLImageProcessor"]) - _import_structure["models.rt_detr"].extend( - ["RTDetrImageProcessor", "RTDetrImageProcessorFast"] - ) + _import_structure["models.rt_detr"].extend(["RTDetrImageProcessor", "RTDetrImageProcessorFast"]) _import_structure["models.sam"].extend(["SamImageProcessor"]) - _import_structure["models.segformer"].extend( - ["SegformerFeatureExtractor", "SegformerImageProcessor"] - ) + _import_structure["models.segformer"].extend(["SegformerFeatureExtractor", "SegformerImageProcessor"]) _import_structure["models.seggpt"].extend(["SegGptImageProcessor"]) _import_structure["models.siglip"].append("SiglipImageProcessor") _import_structure["models.superpoint"].extend(["SuperPointImageProcessor"]) _import_structure["models.swin2sr"].append("Swin2SRImageProcessor") _import_structure["models.tvp"].append("TvpImageProcessor") _import_structure["models.video_llava"].append("VideoLlavaImageProcessor") - _import_structure["models.videomae"].extend( - ["VideoMAEFeatureExtractor", "VideoMAEImageProcessor"] - ) - _import_structure["models.vilt"].extend( - ["ViltFeatureExtractor", "ViltImageProcessor", "ViltProcessor"] - ) - _import_structure["models.vit"].extend( - ["ViTFeatureExtractor", "ViTImageProcessor"] - ) + _import_structure["models.videomae"].extend(["VideoMAEFeatureExtractor", "VideoMAEImageProcessor"]) + _import_structure["models.vilt"].extend(["ViltFeatureExtractor", "ViltImageProcessor", "ViltProcessor"]) + _import_structure["models.vit"].extend(["ViTFeatureExtractor", "ViTImageProcessor"]) _import_structure["models.vitmatte"].append("VitMatteImageProcessor") _import_structure["models.vivit"].append("VivitImageProcessor") - _import_structure["models.yolos"].extend( - ["YolosFeatureExtractor", "YolosImageProcessor"] - ) + _import_structure["models.yolos"].extend(["YolosFeatureExtractor", "YolosImageProcessor"]) _import_structure["models.zoedepth"].append("ZoeDepthImageProcessor") try: @@ -1350,14 +1263,10 @@ from .utils import dummy_torchvision_objects _import_structure["utils.dummy_torchvision_objects"] = [ - name - for name in dir(dummy_torchvision_objects) - if not name.startswith("_") + name for name in dir(dummy_torchvision_objects) if not name.startswith("_") ] else: - _import_structure["image_processing_utils_fast"] = [ - "BaseImageProcessorFast" - ] + _import_structure["image_processing_utils_fast"] = ["BaseImageProcessorFast"] _import_structure["models.vit"].append("ViTImageProcessorFast") # PyTorch-backed objects @@ -1367,15 +1276,11 @@ except OptionalDependencyNotAvailable: from .utils import dummy_pt_objects - _import_structure["utils.dummy_pt_objects"] = [ - name for name in dir(dummy_pt_objects) if not name.startswith("_") - ] + _import_structure["utils.dummy_pt_objects"] = [name for name in dir(dummy_pt_objects) if not name.startswith("_")] else: _import_structure["activations"] = [] _import_structure["benchmark.benchmark"] = ["PyTorchBenchmark"] - _import_structure["benchmark.benchmark_args"] = [ - "PyTorchBenchmarkArguments" - ] + _import_structure["benchmark.benchmark_args"] = ["PyTorchBenchmarkArguments"] _import_structure["cache_utils"] = [ "Cache", "CacheConfig", @@ -1856,9 +1761,7 @@ "CodeGenPreTrainedModel", ] ) - _import_structure["models.cohere"].extend( - ["CohereForCausalLM", "CohereModel", "CoherePreTrainedModel"] - ) + _import_structure["models.cohere"].extend(["CohereForCausalLM", "CohereModel", "CoherePreTrainedModel"]) _import_structure["models.conditional_detr"].extend( [ "ConditionalDetrForObjectDetection", @@ -2064,9 +1967,7 @@ "MegaPreTrainedModel", ] ) - _import_structure["models.deprecated.mmbt"].extend( - ["MMBTForClassification", "MMBTModel", "ModalEmbeddings"] - ) + _import_structure["models.deprecated.mmbt"].extend(["MMBTForClassification", "MMBTModel", "ModalEmbeddings"]) _import_structure["models.deprecated.nat"].extend( [ "NatBackbone", @@ -2368,9 +2269,7 @@ "FocalNetPreTrainedModel", ] ) - _import_structure["models.fsmt"].extend( - ["FSMTForConditionalGeneration", "FSMTModel", "PretrainedFSMTModel"] - ) + _import_structure["models.fsmt"].extend(["FSMTForConditionalGeneration", "FSMTModel", "PretrainedFSMTModel"]) _import_structure["models.funnel"].extend( [ "FunnelBaseModel", @@ -2385,9 +2284,7 @@ "load_tf_weights_in_funnel", ] ) - _import_structure["models.fuyu"].extend( - ["FuyuForCausalLM", "FuyuPreTrainedModel"] - ) + _import_structure["models.fuyu"].extend(["FuyuForCausalLM", "FuyuPreTrainedModel"]) _import_structure["models.gemma"].extend( [ "GemmaForCausalLM", @@ -3199,9 +3096,7 @@ "Pix2StructVisionModel", ] ) - _import_structure["models.pixtral"].extend( - ["PixtralPreTrainedModel", "PixtralVisionModel"] - ) + _import_structure["models.pixtral"].extend(["PixtralPreTrainedModel", "PixtralVisionModel"]) _import_structure["models.plbart"].extend( [ "PLBartForCausalLM", @@ -3474,9 +3369,7 @@ "SiglipVisionModel", ] ) - _import_structure["models.speech_encoder_decoder"].extend( - ["SpeechEncoderDecoderModel"] - ) + _import_structure["models.speech_encoder_decoder"].extend(["SpeechEncoderDecoderModel"]) _import_structure["models.speech_to_text"].extend( [ "Speech2TextForConditionalGeneration", @@ -3718,12 +3611,8 @@ "VipLlavaPreTrainedModel", ] ) - _import_structure["models.vision_encoder_decoder"].extend( - ["VisionEncoderDecoderModel"] - ) - _import_structure["models.vision_text_dual_encoder"].extend( - ["VisionTextDualEncoderModel"] - ) + _import_structure["models.vision_encoder_decoder"].extend(["VisionEncoderDecoderModel"]) + _import_structure["models.vision_text_dual_encoder"].extend(["VisionTextDualEncoderModel"]) _import_structure["models.visual_bert"].extend( [ "VisualBertForMultipleChoice", @@ -3974,14 +3863,10 @@ except OptionalDependencyNotAvailable: from .utils import dummy_tf_objects - _import_structure["utils.dummy_tf_objects"] = [ - name for name in dir(dummy_tf_objects) if not name.startswith("_") - ] + _import_structure["utils.dummy_tf_objects"] = [name for name in dir(dummy_tf_objects) if not name.startswith("_")] else: _import_structure["activations_tf"] = [] - _import_structure["benchmark.benchmark_args_tf"] = [ - "TensorFlowBenchmarkArguments" - ] + _import_structure["benchmark.benchmark_args_tf"] = ["TensorFlowBenchmarkArguments"] _import_structure["benchmark.benchmark_tf"] = ["TensorFlowBenchmark"] _import_structure["generation"].extend( [ @@ -4371,9 +4256,7 @@ "TFLayoutLMv3PreTrainedModel", ] ) - _import_structure["models.led"].extend( - ["TFLEDForConditionalGeneration", "TFLEDModel", "TFLEDPreTrainedModel"] - ) + _import_structure["models.led"].extend(["TFLEDForConditionalGeneration", "TFLEDModel", "TFLEDPreTrainedModel"]) _import_structure["models.longformer"].extend( [ "TFLongformerForMaskedLM", @@ -4394,9 +4277,7 @@ "TFLxmertVisualFeatureEncoder", ] ) - _import_structure["models.marian"].extend( - ["TFMarianModel", "TFMarianMTModel", "TFMarianPreTrainedModel"] - ) + _import_structure["models.marian"].extend(["TFMarianModel", "TFMarianMTModel", "TFMarianPreTrainedModel"]) _import_structure["models.mbart"].extend( [ "TFMBartForConditionalGeneration", @@ -4446,9 +4327,7 @@ "TFMPNetPreTrainedModel", ] ) - _import_structure["models.mt5"].extend( - ["TFMT5EncoderModel", "TFMT5ForConditionalGeneration", "TFMT5Model"] - ) + _import_structure["models.mt5"].extend(["TFMT5EncoderModel", "TFMT5ForConditionalGeneration", "TFMT5Model"]) _import_structure["models.openai"].extend( [ "TFOpenAIGPTDoubleHeadsModel", @@ -4599,12 +4478,8 @@ "TFTapasPreTrainedModel", ] ) - _import_structure["models.vision_encoder_decoder"].extend( - ["TFVisionEncoderDecoderModel"] - ) - _import_structure["models.vision_text_dual_encoder"].extend( - ["TFVisionTextDualEncoderModel"] - ) + _import_structure["models.vision_encoder_decoder"].extend(["TFVisionEncoderDecoderModel"]) + _import_structure["models.vision_text_dual_encoder"].extend(["TFVisionTextDualEncoderModel"]) _import_structure["models.vit"].extend( [ "TFViTForImageClassification", @@ -4700,13 +4575,9 @@ dummy_essentia_and_librosa_and_pretty_midi_and_scipy_and_torch_objects, ) - _import_structure[ - "utils.dummy_essentia_and_librosa_and_pretty_midi_and_scipy_and_torch_objects" - ] = [ + _import_structure["utils.dummy_essentia_and_librosa_and_pretty_midi_and_scipy_and_torch_objects"] = [ name - for name in dir( - dummy_essentia_and_librosa_and_pretty_midi_and_scipy_and_torch_objects - ) + for name in dir(dummy_essentia_and_librosa_and_pretty_midi_and_scipy_and_torch_objects) if not name.startswith("_") ] else: @@ -4721,17 +4592,11 @@ from .utils import dummy_torchaudio_objects _import_structure["utils.dummy_torchaudio_objects"] = [ - name - for name in dir(dummy_torchaudio_objects) - if not name.startswith("_") + name for name in dir(dummy_torchaudio_objects) if not name.startswith("_") ] else: - _import_structure["models.musicgen_melody"].append( - "MusicgenMelodyFeatureExtractor" - ) - _import_structure["models.musicgen_melody"].append( - "MusicgenMelodyProcessor" - ) + _import_structure["models.musicgen_melody"].append("MusicgenMelodyFeatureExtractor") + _import_structure["models.musicgen_melody"].append("MusicgenMelodyProcessor") # FLAX-backed objects @@ -4921,12 +4786,8 @@ "FlaxElectraPreTrainedModel", ] ) - _import_structure["models.encoder_decoder"].append( - "FlaxEncoderDecoderModel" - ) - _import_structure["models.gpt2"].extend( - ["FlaxGPT2LMHeadModel", "FlaxGPT2Model", "FlaxGPT2PreTrainedModel"] - ) + _import_structure["models.encoder_decoder"].append("FlaxEncoderDecoderModel") + _import_structure["models.gpt2"].extend(["FlaxGPT2LMHeadModel", "FlaxGPT2Model", "FlaxGPT2PreTrainedModel"]) _import_structure["models.gpt_neo"].extend( [ "FlaxGPTNeoForCausalLM", @@ -4934,15 +4795,9 @@ "FlaxGPTNeoPreTrainedModel", ] ) - _import_structure["models.gptj"].extend( - ["FlaxGPTJForCausalLM", "FlaxGPTJModel", "FlaxGPTJPreTrainedModel"] - ) - _import_structure["models.llama"].extend( - ["FlaxLlamaForCausalLM", "FlaxLlamaModel", "FlaxLlamaPreTrainedModel"] - ) - _import_structure["models.gemma"].extend( - ["FlaxGemmaForCausalLM", "FlaxGemmaModel", "FlaxGemmaPreTrainedModel"] - ) + _import_structure["models.gptj"].extend(["FlaxGPTJForCausalLM", "FlaxGPTJModel", "FlaxGPTJPreTrainedModel"]) + _import_structure["models.llama"].extend(["FlaxLlamaForCausalLM", "FlaxLlamaModel", "FlaxLlamaPreTrainedModel"]) + _import_structure["models.gemma"].extend(["FlaxGemmaForCausalLM", "FlaxGemmaModel", "FlaxGemmaPreTrainedModel"]) _import_structure["models.longt5"].extend( [ "FlaxLongT5ForConditionalGeneration", @@ -5043,9 +4898,7 @@ "FlaxRoFormerPreTrainedModel", ] ) - _import_structure["models.speech_encoder_decoder"].append( - "FlaxSpeechEncoderDecoderModel" - ) + _import_structure["models.speech_encoder_decoder"].append("FlaxSpeechEncoderDecoderModel") _import_structure["models.t5"].extend( [ "FlaxT5EncoderModel", @@ -5054,12 +4907,8 @@ "FlaxT5PreTrainedModel", ] ) - _import_structure["models.vision_encoder_decoder"].append( - "FlaxVisionEncoderDecoderModel" - ) - _import_structure["models.vision_text_dual_encoder"].extend( - ["FlaxVisionTextDualEncoderModel"] - ) + _import_structure["models.vision_encoder_decoder"].append("FlaxVisionEncoderDecoderModel") + _import_structure["models.vision_text_dual_encoder"].extend(["FlaxVisionTextDualEncoderModel"]) _import_structure["models.vit"].extend( [ "FlaxViTForImageClassification", @@ -9136,11 +8985,7 @@ ) -if ( - not is_tf_available() - and not is_torch_available() - and not is_flax_available() -): +if not is_tf_available() and not is_torch_available() and not is_flax_available(): logger.warning_advice( "None of PyTorch, TensorFlow >= 2.0, or Flax have been found. " "Models won't be available and only tokenizers, configuration " diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 2f4e37f0e6cd..9ec932fe7eec 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -28,6 +28,7 @@ ) from ...utils import CONFIG_NAME, logging + logger = logging.get_logger(__name__) @@ -741,9 +742,7 @@ def __getitem__(self, key): value = self._mapping[key] module_name = model_type_to_module_name(key) if module_name not in self._modules: - self._modules[module_name] = importlib.import_module( - f".{module_name}", "transformers.models" - ) + self._modules[module_name] = importlib.import_module(f".{module_name}", "transformers.models") if hasattr(self._modules[module_name], value): return getattr(self._modules[module_name], value) @@ -756,19 +755,13 @@ def keys(self): return list(self._mapping.keys()) + list(self._extra_content.keys()) def values(self): - return [self[k] for k in self._mapping.keys()] + list( - self._extra_content.values() - ) + return [self[k] for k in self._mapping.keys()] + list(self._extra_content.values()) def items(self): - return [(k, self[k]) for k in self._mapping.keys()] + list( - self._extra_content.items() - ) + return [(k, self[k]) for k in self._mapping.keys()] + list(self._extra_content.items()) def __iter__(self): - return iter( - list(self._mapping.keys()) + list(self._extra_content.keys()) - ) + return iter(list(self._mapping.keys()) + list(self._extra_content.keys())) def __contains__(self, item): return item in self._mapping or item in self._extra_content @@ -778,9 +771,7 @@ def register(self, key, value, exist_ok=False): Register a new configuration in this mapping. """ if key in self._mapping.keys() and not exist_ok: - raise ValueError( - f"'{key}' is already used by a Transformers config, pick another name." - ) + raise ValueError(f"'{key}' is already used by a Transformers config, pick another name.") self._extra_content[key] = value @@ -807,9 +798,7 @@ def _initialize(self): for model_type, map_name in self._mapping.items(): module_name = model_type_to_module_name(model_type) - module = importlib.import_module( - f".{module_name}", "transformers.models" - ) + module = importlib.import_module(f".{module_name}", "transformers.models") mapping = getattr(module, map_name) self._data.update(mapping) @@ -848,15 +837,10 @@ def _get_class_name(model_class: Union[str, List[str]]): def _list_model_options(indent, config_to_class=None, use_model_types=True): if config_to_class is None and not use_model_types: - raise ValueError( - "Using `use_model_types=False` requires a `config_to_class` dictionary." - ) + raise ValueError("Using `use_model_types=False` requires a `config_to_class` dictionary.") if use_model_types: if config_to_class is None: - model_type_to_name = { - model_type: f"[`{config}`]" - for model_type, config in CONFIG_MAPPING_NAMES.items() - } + model_type_to_name = {model_type: f"[`{config}`]" for model_type, config in CONFIG_MAPPING_NAMES.items()} else: model_type_to_name = { model_type: _get_class_name(model_class) @@ -874,8 +858,7 @@ def _list_model_options(indent, config_to_class=None, use_model_types=True): if config in CONFIG_MAPPING_NAMES } config_to_model_name = { - config: MODEL_NAMES_MAPPING[model_type] - for model_type, config in CONFIG_MAPPING_NAMES.items() + config: MODEL_NAMES_MAPPING[model_type] for model_type, config in CONFIG_MAPPING_NAMES.items() } lines = [ f"{indent}- [`{config_name}`] configuration class:" @@ -885,9 +868,7 @@ def _list_model_options(indent, config_to_class=None, use_model_types=True): return "\n".join(lines) -def replace_list_option_in_docstrings( - config_to_class=None, use_model_types=True -): +def replace_list_option_in_docstrings(config_to_class=None, use_model_types=True): def docstring_decorator(fn): docstrings = fn.__doc__ if docstrings is None: @@ -895,10 +876,7 @@ def docstring_decorator(fn): return fn lines = docstrings.split("\n") i = 0 - while ( - i < len(lines) - and re.search(r"^(\s*)List options\s*$", lines[i]) is None - ): + while i < len(lines) and re.search(r"^(\s*)List options\s*$", lines[i]) is None: i += 1 if i < len(lines): indent = re.search(r"^(\s*)List options\s*$", lines[i]).groups()[0] @@ -1045,17 +1023,9 @@ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): trust_remote_code = kwargs.pop("trust_remote_code", None) code_revision = kwargs.pop("code_revision", None) - config_dict, unused_kwargs = PretrainedConfig.get_config_dict( - pretrained_model_name_or_path, **kwargs - ) - has_remote_code = ( - "auto_map" in config_dict - and "AutoConfig" in config_dict["auto_map"] - ) - has_local_code = ( - "model_type" in config_dict - and config_dict["model_type"] in CONFIG_MAPPING - ) + config_dict, unused_kwargs = PretrainedConfig.get_config_dict(pretrained_model_name_or_path, **kwargs) + has_remote_code = "auto_map" in config_dict and "AutoConfig" in config_dict["auto_map"] + has_local_code = "model_type" in config_dict and config_dict["model_type"] in CONFIG_MAPPING trust_remote_code = resolve_trust_remote_code( trust_remote_code, pretrained_model_name_or_path, @@ -1073,9 +1043,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): ) if os.path.isdir(pretrained_model_name_or_path): config_class.register_for_auto_class() - return config_class.from_pretrained( - pretrained_model_name_or_path, **kwargs - ) + return config_class.from_pretrained(pretrained_model_name_or_path, **kwargs) elif "model_type" in config_dict: try: config_class = CONFIG_MAPPING[config_dict["model_type"]] @@ -1089,13 +1057,9 @@ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): else: # Fallback: use pattern matching on the string. # We go from longer names to shorter names to catch roberta before bert (for instance) - for pattern in sorted( - CONFIG_MAPPING.keys(), key=len, reverse=True - ): + for pattern in sorted(CONFIG_MAPPING.keys(), key=len, reverse=True): if pattern in str(pretrained_model_name_or_path): - return CONFIG_MAPPING[pattern].from_dict( - config_dict, **unused_kwargs - ) + return CONFIG_MAPPING[pattern].from_dict(config_dict, **unused_kwargs) raise ValueError( f"Unrecognized model in {pretrained_model_name_or_path}. " @@ -1112,10 +1076,7 @@ def register(model_type, config, exist_ok=False): model_type (`str`): The model type like "bert" or "gpt". config ([`PretrainedConfig`]): The config to register. """ - if ( - issubclass(config, PretrainedConfig) - and config.model_type != model_type - ): + if issubclass(config, PretrainedConfig) and config.model_type != model_type: raise ValueError( "The config you are passing has a `model_type` attribute that is not consistent with the model type " f"you passed (config has {config.model_type} and you passed {model_type}. Fix one of those so they " diff --git a/src/transformers/models/auto/image_processing_auto.py b/src/transformers/models/auto/image_processing_auto.py index 1d59744839ff..b7202ff578d9 100644 --- a/src/transformers/models/auto/image_processing_auto.py +++ b/src/transformers/models/auto/image_processing_auto.py @@ -45,15 +45,14 @@ replace_list_option_in_docstrings, ) + logger = logging.get_logger(__name__) if TYPE_CHECKING: # This significantly improves completion suggestion performance when # the transformers package is used with Microsoft's Pylance language server. - IMAGE_PROCESSOR_MAPPING_NAMES: OrderedDict[ - str, Tuple[Optional[str], Optional[str]] - ] = OrderedDict() + IMAGE_PROCESSOR_MAPPING_NAMES: OrderedDict[str, Tuple[Optional[str], Optional[str]]] = OrderedDict() else: IMAGE_PROCESSOR_MAPPING_NAMES = OrderedDict( [ @@ -165,11 +164,7 @@ slow_image_processor_class = None # If the fast image processor is not defined, or torchvision is not available, we set it to None - if ( - not fast_image_processor_class - or fast_image_processor_class[0] is None - or not is_torchvision_available() - ): + if not fast_image_processor_class or fast_image_processor_class[0] is None or not is_torchvision_available(): fast_image_processor_class = None else: fast_image_processor_class = fast_image_processor_class[0] @@ -179,9 +174,7 @@ fast_image_processor_class, ) -IMAGE_PROCESSOR_MAPPING = _LazyAutoMapping( - CONFIG_MAPPING_NAMES, IMAGE_PROCESSOR_MAPPING_NAMES -) +IMAGE_PROCESSOR_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, IMAGE_PROCESSOR_MAPPING_NAMES) def image_processor_class_from_name(class_name: str): @@ -192,9 +185,7 @@ def image_processor_class_from_name(class_name: str): if class_name in extractors: module_name = model_type_to_module_name(module_name) - module = importlib.import_module( - f".{module_name}", "transformers.models" - ) + module = importlib.import_module(f".{module_name}", "transformers.models") try: return getattr(module, class_name) except AttributeError: @@ -290,9 +281,7 @@ def get_image_processor_config( FutureWarning, ) if token is not None: - raise ValueError( - "`token` and `use_auth_token` are both specified. Please set only the argument `token`." - ) + raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.") token = use_auth_token resolved_config_file = get_file_from_repo( @@ -431,33 +420,21 @@ def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs): trust_remote_code = kwargs.pop("trust_remote_code", None) kwargs["_from_auto"] = True - config_dict, _ = ImageProcessingMixin.get_image_processor_dict( - pretrained_model_name_or_path, **kwargs - ) + config_dict, _ = ImageProcessingMixin.get_image_processor_dict(pretrained_model_name_or_path, **kwargs) image_processor_class = config_dict.get("image_processor_type", None) image_processor_auto_map = None if "AutoImageProcessor" in config_dict.get("auto_map", {}): - image_processor_auto_map = config_dict["auto_map"][ - "AutoImageProcessor" - ] + image_processor_auto_map = config_dict["auto_map"]["AutoImageProcessor"] # If we still don't have the image processor class, check if we're loading from a previous feature extractor config # and if so, infer the image processor class from there. if image_processor_class is None and image_processor_auto_map is None: - feature_extractor_class = config_dict.pop( - "feature_extractor_type", None - ) + feature_extractor_class = config_dict.pop("feature_extractor_type", None) if feature_extractor_class is not None: - image_processor_class = feature_extractor_class.replace( - "FeatureExtractor", "ImageProcessor" - ) + image_processor_class = feature_extractor_class.replace("FeatureExtractor", "ImageProcessor") if "AutoFeatureExtractor" in config_dict.get("auto_map", {}): - feature_extractor_auto_map = config_dict["auto_map"][ - "AutoFeatureExtractor" - ] - image_processor_auto_map = feature_extractor_auto_map.replace( - "FeatureExtractor", "ImageProcessor" - ) + feature_extractor_auto_map = config_dict["auto_map"]["AutoFeatureExtractor"] + image_processor_auto_map = feature_extractor_auto_map.replace("FeatureExtractor", "ImageProcessor") # If we don't find the image processor class in the image processor config, let's try the model config. if image_processor_class is None and image_processor_auto_map is None: @@ -468,16 +445,9 @@ def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs): **kwargs, ) # It could be in `config.image_processor_type`` - image_processor_class = getattr( - config, "image_processor_type", None - ) - if ( - hasattr(config, "auto_map") - and "AutoImageProcessor" in config.auto_map - ): - image_processor_auto_map = config.auto_map[ - "AutoImageProcessor" - ] + image_processor_class = getattr(config, "image_processor_type", None) + if hasattr(config, "auto_map") and "AutoImageProcessor" in config.auto_map: + image_processor_auto_map = config.auto_map["AutoImageProcessor"] if image_processor_class is not None: # Update class name to reflect the use_fast option. If class is not found, None is returned. @@ -486,15 +456,10 @@ def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs): image_processor_class += "Fast" elif not use_fast and image_processor_class.endswith("Fast"): image_processor_class = image_processor_class[:-4] - image_processor_class = image_processor_class_from_name( - image_processor_class - ) + image_processor_class = image_processor_class_from_name(image_processor_class) has_remote_code = image_processor_auto_map is not None - has_local_code = ( - image_processor_class is not None - or type(config) in IMAGE_PROCESSOR_MAPPING - ) + has_local_code = image_processor_class is not None or type(config) in IMAGE_PROCESSOR_MAPPING trust_remote_code = resolve_trust_remote_code( trust_remote_code, pretrained_model_name_or_path, @@ -502,25 +467,19 @@ def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs): has_remote_code, ) - if image_processor_auto_map is not None and not isinstance( - image_processor_auto_map, tuple - ): + if image_processor_auto_map is not None and not isinstance(image_processor_auto_map, tuple): # In some configs, only the slow image processor class is stored image_processor_auto_map = (image_processor_auto_map, None) if has_remote_code and trust_remote_code: if not use_fast and image_processor_auto_map[1] is not None: - _warning_fast_image_processor_available( - image_processor_auto_map[1] - ) + _warning_fast_image_processor_available(image_processor_auto_map[1]) if use_fast and image_processor_auto_map[1] is not None: class_ref = image_processor_auto_map[1] else: class_ref = image_processor_auto_map[0] - image_processor_class = get_class_from_dynamic_module( - class_ref, pretrained_model_name_or_path, **kwargs - ) + image_processor_class = get_class_from_dynamic_module(class_ref, pretrained_model_name_or_path, **kwargs) _ = kwargs.pop("code_revision", None) if os.path.isdir(pretrained_model_name_or_path): image_processor_class.register_for_auto_class() @@ -531,26 +490,16 @@ def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs): elif type(config) in IMAGE_PROCESSOR_MAPPING: image_processor_tuple = IMAGE_PROCESSOR_MAPPING[type(config)] - image_processor_class_py, image_processor_class_fast = ( - image_processor_tuple - ) + image_processor_class_py, image_processor_class_fast = image_processor_tuple if not use_fast and image_processor_class_fast is not None: - _warning_fast_image_processor_available( - image_processor_class_fast - ) + _warning_fast_image_processor_available(image_processor_class_fast) - if image_processor_class_fast and ( - use_fast or image_processor_class_py is None - ): - return image_processor_class_fast.from_pretrained( - pretrained_model_name_or_path, *inputs, **kwargs - ) + if image_processor_class_fast and (use_fast or image_processor_class_py is None): + return image_processor_class_fast.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) else: if image_processor_class_py is not None: - return image_processor_class_py.from_pretrained( - pretrained_model_name_or_path, *inputs, **kwargs - ) + return image_processor_class_py.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) else: raise ValueError( "This image processor cannot be instantiated. Please make sure you have `Pillow` installed." @@ -580,41 +529,25 @@ def register( """ if image_processor_class is not None: if slow_image_processor_class is not None: - raise ValueError( - "Cannot specify both image_processor_class and slow_image_processor_class" - ) + raise ValueError("Cannot specify both image_processor_class and slow_image_processor_class") warnings.warn( "The image_processor_class argument is deprecated and will be removed in v4.42. Please use `slow_image_processor_class`, or `fast_image_processor_class` instead", FutureWarning, ) slow_image_processor_class = image_processor_class - if ( - slow_image_processor_class is None - and fast_image_processor_class is None - ): - raise ValueError( - "You need to specify either slow_image_processor_class or fast_image_processor_class" - ) - if slow_image_processor_class is not None and issubclass( - slow_image_processor_class, BaseImageProcessorFast - ): - raise ValueError( - "You passed a fast image processor in as the `slow_image_processor_class`." - ) - if fast_image_processor_class is not None and issubclass( - fast_image_processor_class, BaseImageProcessor - ): - raise ValueError( - "You passed a slow image processor in as the `fast_image_processor_class`." - ) + if slow_image_processor_class is None and fast_image_processor_class is None: + raise ValueError("You need to specify either slow_image_processor_class or fast_image_processor_class") + if slow_image_processor_class is not None and issubclass(slow_image_processor_class, BaseImageProcessorFast): + raise ValueError("You passed a fast image processor in as the `slow_image_processor_class`.") + if fast_image_processor_class is not None and issubclass(fast_image_processor_class, BaseImageProcessor): + raise ValueError("You passed a slow image processor in as the `fast_image_processor_class`.") if ( slow_image_processor_class is not None and fast_image_processor_class is not None and issubclass(fast_image_processor_class, BaseImageProcessorFast) - and fast_image_processor_class.slow_image_processor_class - != slow_image_processor_class + and fast_image_processor_class.slow_image_processor_class != slow_image_processor_class ): raise ValueError( "The fast processor class you are passing has a `slow_image_processor_class` attribute that is not " @@ -625,9 +558,7 @@ def register( # Avoid resetting a set slow/fast image processor if we are passing just the other ones. if config_class in IMAGE_PROCESSOR_MAPPING._extra_content: - existing_slow, existing_fast = IMAGE_PROCESSOR_MAPPING[ - config_class - ] + existing_slow, existing_fast = IMAGE_PROCESSOR_MAPPING[config_class] if slow_image_processor_class is None: slow_image_processor_class = existing_slow if fast_image_processor_class is None: diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index a96085dab986..e0cd0c9423a1 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -26,6 +26,7 @@ ) from .configuration_auto import CONFIG_MAPPING_NAMES + logger = logging.get_logger(__name__) MODEL_MAPPING_NAMES = OrderedDict( @@ -1438,15 +1439,9 @@ ) MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_MAPPING_NAMES) -MODEL_FOR_PRETRAINING_MAPPING = _LazyAutoMapping( - CONFIG_MAPPING_NAMES, MODEL_FOR_PRETRAINING_MAPPING_NAMES -) -MODEL_WITH_LM_HEAD_MAPPING = _LazyAutoMapping( - CONFIG_MAPPING_NAMES, MODEL_WITH_LM_HEAD_MAPPING_NAMES -) -MODEL_FOR_CAUSAL_LM_MAPPING = _LazyAutoMapping( - CONFIG_MAPPING_NAMES, MODEL_FOR_CAUSAL_LM_MAPPING_NAMES -) +MODEL_FOR_PRETRAINING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_PRETRAINING_MAPPING_NAMES) +MODEL_WITH_LM_HEAD_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_WITH_LM_HEAD_MAPPING_NAMES) +MODEL_FOR_CAUSAL_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_CAUSAL_LM_MAPPING_NAMES) MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING = _LazyAutoMapping( CONFIG_MAPPING_NAMES, MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING_NAMES ) @@ -1472,9 +1467,7 @@ MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING = _LazyAutoMapping( CONFIG_MAPPING_NAMES, MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING_NAMES ) -MODEL_FOR_VISION_2_SEQ_MAPPING = _LazyAutoMapping( - CONFIG_MAPPING_NAMES, MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES -) +MODEL_FOR_VISION_2_SEQ_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES) MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING = _LazyAutoMapping( CONFIG_MAPPING_NAMES, MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES ) @@ -1484,24 +1477,16 @@ MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping( CONFIG_MAPPING_NAMES, MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES ) -MODEL_FOR_MASKED_LM_MAPPING = _LazyAutoMapping( - CONFIG_MAPPING_NAMES, MODEL_FOR_MASKED_LM_MAPPING_NAMES -) -MODEL_FOR_IMAGE_MAPPING = _LazyAutoMapping( - CONFIG_MAPPING_NAMES, MODEL_FOR_IMAGE_MAPPING_NAMES -) +MODEL_FOR_MASKED_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_MASKED_LM_MAPPING_NAMES) +MODEL_FOR_IMAGE_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_IMAGE_MAPPING_NAMES) MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING = _LazyAutoMapping( CONFIG_MAPPING_NAMES, MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES ) -MODEL_FOR_OBJECT_DETECTION_MAPPING = _LazyAutoMapping( - CONFIG_MAPPING_NAMES, MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES -) +MODEL_FOR_OBJECT_DETECTION_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES) MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING = _LazyAutoMapping( CONFIG_MAPPING_NAMES, MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES ) -MODEL_FOR_DEPTH_ESTIMATION_MAPPING = _LazyAutoMapping( - CONFIG_MAPPING_NAMES, MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES -) +MODEL_FOR_DEPTH_ESTIMATION_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES) MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = _LazyAutoMapping( CONFIG_MAPPING_NAMES, MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES ) @@ -1517,51 +1502,35 @@ MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = _LazyAutoMapping( CONFIG_MAPPING_NAMES, MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES ) -MODEL_FOR_MULTIPLE_CHOICE_MAPPING = _LazyAutoMapping( - CONFIG_MAPPING_NAMES, MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES -) +MODEL_FOR_MULTIPLE_CHOICE_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES) MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = _LazyAutoMapping( CONFIG_MAPPING_NAMES, MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES ) MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING = _LazyAutoMapping( CONFIG_MAPPING_NAMES, MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES ) -MODEL_FOR_CTC_MAPPING = _LazyAutoMapping( - CONFIG_MAPPING_NAMES, MODEL_FOR_CTC_MAPPING_NAMES -) -MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING = _LazyAutoMapping( - CONFIG_MAPPING_NAMES, MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES -) +MODEL_FOR_CTC_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_CTC_MAPPING_NAMES) +MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES) MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING = _LazyAutoMapping( CONFIG_MAPPING_NAMES, MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING_NAMES ) -MODEL_FOR_AUDIO_XVECTOR_MAPPING = _LazyAutoMapping( - CONFIG_MAPPING_NAMES, MODEL_FOR_AUDIO_XVECTOR_MAPPING_NAMES -) +MODEL_FOR_AUDIO_XVECTOR_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_AUDIO_XVECTOR_MAPPING_NAMES) MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING = _LazyAutoMapping( CONFIG_MAPPING_NAMES, MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING_NAMES ) -MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING = _LazyAutoMapping( - CONFIG_MAPPING_NAMES, MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING_NAMES -) +MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING_NAMES) -MODEL_FOR_BACKBONE_MAPPING = _LazyAutoMapping( - CONFIG_MAPPING_NAMES, MODEL_FOR_BACKBONE_MAPPING_NAMES -) +MODEL_FOR_BACKBONE_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_BACKBONE_MAPPING_NAMES) -MODEL_FOR_MASK_GENERATION_MAPPING = _LazyAutoMapping( - CONFIG_MAPPING_NAMES, MODEL_FOR_MASK_GENERATION_MAPPING_NAMES -) +MODEL_FOR_MASK_GENERATION_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_MASK_GENERATION_MAPPING_NAMES) MODEL_FOR_KEYPOINT_DETECTION_MAPPING = _LazyAutoMapping( CONFIG_MAPPING_NAMES, MODEL_FOR_KEYPOINT_DETECTION_MAPPING_NAMES ) -MODEL_FOR_TEXT_ENCODING_MAPPING = _LazyAutoMapping( - CONFIG_MAPPING_NAMES, MODEL_FOR_TEXT_ENCODING_MAPPING_NAMES -) +MODEL_FOR_TEXT_ENCODING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_TEXT_ENCODING_MAPPING_NAMES) MODEL_FOR_TIME_SERIES_CLASSIFICATION_MAPPING = _LazyAutoMapping( CONFIG_MAPPING_NAMES, MODEL_FOR_TIME_SERIES_CLASSIFICATION_MAPPING_NAMES @@ -1571,9 +1540,7 @@ CONFIG_MAPPING_NAMES, MODEL_FOR_TIME_SERIES_REGRESSION_MAPPING_NAMES ) -MODEL_FOR_IMAGE_TO_IMAGE_MAPPING = _LazyAutoMapping( - CONFIG_MAPPING_NAMES, MODEL_FOR_IMAGE_TO_IMAGE_MAPPING_NAMES -) +MODEL_FOR_IMAGE_TO_IMAGE_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_IMAGE_TO_IMAGE_MAPPING_NAMES) class AutoModelForMaskGeneration(_BaseAutoModelClass): @@ -1603,9 +1570,7 @@ class AutoModelForPreTraining(_BaseAutoModelClass): _model_mapping = MODEL_FOR_PRETRAINING_MAPPING -AutoModelForPreTraining = auto_class_update( - AutoModelForPreTraining, head_doc="pretraining" -) +AutoModelForPreTraining = auto_class_update(AutoModelForPreTraining, head_doc="pretraining") # Private on purpose, the public class will add the deprecation warnings. @@ -1613,27 +1578,21 @@ class _AutoModelWithLMHead(_BaseAutoModelClass): _model_mapping = MODEL_WITH_LM_HEAD_MAPPING -_AutoModelWithLMHead = auto_class_update( - _AutoModelWithLMHead, head_doc="language modeling" -) +_AutoModelWithLMHead = auto_class_update(_AutoModelWithLMHead, head_doc="language modeling") class AutoModelForCausalLM(_BaseAutoModelClass): _model_mapping = MODEL_FOR_CAUSAL_LM_MAPPING -AutoModelForCausalLM = auto_class_update( - AutoModelForCausalLM, head_doc="causal language modeling" -) +AutoModelForCausalLM = auto_class_update(AutoModelForCausalLM, head_doc="causal language modeling") class AutoModelForMaskedLM(_BaseAutoModelClass): _model_mapping = MODEL_FOR_MASKED_LM_MAPPING -AutoModelForMaskedLM = auto_class_update( - AutoModelForMaskedLM, head_doc="masked language modeling" -) +AutoModelForMaskedLM = auto_class_update(AutoModelForMaskedLM, head_doc="masked language modeling") class AutoModelForSeq2SeqLM(_BaseAutoModelClass): @@ -1660,9 +1619,7 @@ class AutoModelForQuestionAnswering(_BaseAutoModelClass): _model_mapping = MODEL_FOR_QUESTION_ANSWERING_MAPPING -AutoModelForQuestionAnswering = auto_class_update( - AutoModelForQuestionAnswering, head_doc="question answering" -) +AutoModelForQuestionAnswering = auto_class_update(AutoModelForQuestionAnswering, head_doc="question answering") class AutoModelForTableQuestionAnswering(_BaseAutoModelClass): @@ -1702,18 +1659,14 @@ class AutoModelForTokenClassification(_BaseAutoModelClass): _model_mapping = MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING -AutoModelForTokenClassification = auto_class_update( - AutoModelForTokenClassification, head_doc="token classification" -) +AutoModelForTokenClassification = auto_class_update(AutoModelForTokenClassification, head_doc="token classification") class AutoModelForMultipleChoice(_BaseAutoModelClass): _model_mapping = MODEL_FOR_MULTIPLE_CHOICE_MAPPING -AutoModelForMultipleChoice = auto_class_update( - AutoModelForMultipleChoice, head_doc="multiple choice" -) +AutoModelForMultipleChoice = auto_class_update(AutoModelForMultipleChoice, head_doc="multiple choice") class AutoModelForNextSentencePrediction(_BaseAutoModelClass): @@ -1729,9 +1682,7 @@ class AutoModelForImageClassification(_BaseAutoModelClass): _model_mapping = MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING -AutoModelForImageClassification = auto_class_update( - AutoModelForImageClassification, head_doc="image classification" -) +AutoModelForImageClassification = auto_class_update(AutoModelForImageClassification, head_doc="image classification") class AutoModelForZeroShotImageClassification(_BaseAutoModelClass): @@ -1748,9 +1699,7 @@ class AutoModelForImageSegmentation(_BaseAutoModelClass): _model_mapping = MODEL_FOR_IMAGE_SEGMENTATION_MAPPING -AutoModelForImageSegmentation = auto_class_update( - AutoModelForImageSegmentation, head_doc="image segmentation" -) +AutoModelForImageSegmentation = auto_class_update(AutoModelForImageSegmentation, head_doc="image segmentation") class AutoModelForSemanticSegmentation(_BaseAutoModelClass): @@ -1784,9 +1733,7 @@ class AutoModelForObjectDetection(_BaseAutoModelClass): _model_mapping = MODEL_FOR_OBJECT_DETECTION_MAPPING -AutoModelForObjectDetection = auto_class_update( - AutoModelForObjectDetection, head_doc="object detection" -) +AutoModelForObjectDetection = auto_class_update(AutoModelForObjectDetection, head_doc="object detection") class AutoModelForZeroShotObjectDetection(_BaseAutoModelClass): @@ -1802,54 +1749,42 @@ class AutoModelForDepthEstimation(_BaseAutoModelClass): _model_mapping = MODEL_FOR_DEPTH_ESTIMATION_MAPPING -AutoModelForDepthEstimation = auto_class_update( - AutoModelForDepthEstimation, head_doc="depth estimation" -) +AutoModelForDepthEstimation = auto_class_update(AutoModelForDepthEstimation, head_doc="depth estimation") class AutoModelForVideoClassification(_BaseAutoModelClass): _model_mapping = MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING -AutoModelForVideoClassification = auto_class_update( - AutoModelForVideoClassification, head_doc="video classification" -) +AutoModelForVideoClassification = auto_class_update(AutoModelForVideoClassification, head_doc="video classification") class AutoModelForVision2Seq(_BaseAutoModelClass): _model_mapping = MODEL_FOR_VISION_2_SEQ_MAPPING -AutoModelForVision2Seq = auto_class_update( - AutoModelForVision2Seq, head_doc="vision-to-text modeling" -) +AutoModelForVision2Seq = auto_class_update(AutoModelForVision2Seq, head_doc="vision-to-text modeling") class AutoModelForImageTextToText(_BaseAutoModelClass): _model_mapping = MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING -AutoModelForImageTextToText = auto_class_update( - AutoModelForImageTextToText, head_doc="image-text-to-text modeling" -) +AutoModelForImageTextToText = auto_class_update(AutoModelForImageTextToText, head_doc="image-text-to-text modeling") class AutoModelForAudioClassification(_BaseAutoModelClass): _model_mapping = MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING -AutoModelForAudioClassification = auto_class_update( - AutoModelForAudioClassification, head_doc="audio classification" -) +AutoModelForAudioClassification = auto_class_update(AutoModelForAudioClassification, head_doc="audio classification") class AutoModelForCTC(_BaseAutoModelClass): _model_mapping = MODEL_FOR_CTC_MAPPING -AutoModelForCTC = auto_class_update( - AutoModelForCTC, head_doc="connectionist temporal classification" -) +AutoModelForCTC = auto_class_update(AutoModelForCTC, head_doc="connectionist temporal classification") class AutoModelForSpeechSeq2Seq(_BaseAutoModelClass): @@ -1888,18 +1823,14 @@ class AutoBackbone(_BaseAutoBackboneClass): _model_mapping = MODEL_FOR_BACKBONE_MAPPING -AutoModelForAudioXVector = auto_class_update( - AutoModelForAudioXVector, head_doc="audio retrieval via x-vector" -) +AutoModelForAudioXVector = auto_class_update(AutoModelForAudioXVector, head_doc="audio retrieval via x-vector") class AutoModelForMaskedImageModeling(_BaseAutoModelClass): _model_mapping = MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING -AutoModelForMaskedImageModeling = auto_class_update( - AutoModelForMaskedImageModeling, head_doc="masked image modeling" -) +AutoModelForMaskedImageModeling = auto_class_update(AutoModelForMaskedImageModeling, head_doc="masked image modeling") class AutoModelWithLMHead(_AutoModelWithLMHead): @@ -1914,15 +1845,11 @@ def from_config(cls, config): return super().from_config(config) @classmethod - def from_pretrained( - cls, pretrained_model_name_or_path, *model_args, **kwargs - ): + def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): warnings.warn( "The class `AutoModelWithLMHead` is deprecated and will be removed in a future version. Please use " "`AutoModelForCausalLM` for causal language models, `AutoModelForMaskedLM` for masked language models and " "`AutoModelForSeq2SeqLM` for encoder-decoder models.", FutureWarning, ) - return super().from_pretrained( - pretrained_model_name_or_path, *model_args, **kwargs - ) + return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 2643fb70ac15..2a338a9dc5d4 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -10201,9 +10201,7 @@ def get_cosine_schedule_with_warmup(*args, **kwargs): def get_cosine_with_hard_restarts_schedule_with_warmup(*args, **kwargs): - requires_backends( - get_cosine_with_hard_restarts_schedule_with_warmup, ["torch"] - ) + requires_backends(get_cosine_with_hard_restarts_schedule_with_warmup, ["torch"]) def get_inverse_sqrt_schedule(*args, **kwargs): From 52f2173de84fdf7967494e128c5777dc3ec99ad7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Marcelo=20Tozato?= Date: Sat, 16 Nov 2024 19:47:41 +0100 Subject: [PATCH 24/44] adjust modeling to modular code --- src/transformers/models/ijepa/__init__.py | 4 +- .../models/ijepa/convert_ijepa_to_hf.py | 165 ++++++------- .../models/ijepa/modeling_ijepa.py | 216 ++++++++---------- .../models/ijepa/modular_ijepa.py | 205 +++++++++++++++++ 4 files changed, 373 insertions(+), 217 deletions(-) create mode 100644 src/transformers/models/ijepa/modular_ijepa.py diff --git a/src/transformers/models/ijepa/__init__.py b/src/transformers/models/ijepa/__init__.py index 50ab72784d98..adaefe9ae43c 100644 --- a/src/transformers/models/ijepa/__init__.py +++ b/src/transformers/models/ijepa/__init__.py @@ -20,7 +20,7 @@ ) -_import_structure = {"configuration_ijepa": ["IJepaConfig", "IJepaOnnxConfig"]} +_import_structure = {"configuration_ijepa": ["IJepaConfig"]} try: if not is_torch_available(): @@ -35,7 +35,7 @@ ] if TYPE_CHECKING: - from .configuration_ijepa import IJepaConfig, IJepaOnnxConfig + from .configuration_ijepa import IJepaConfig try: if not is_torch_available(): diff --git a/src/transformers/models/ijepa/convert_ijepa_to_hf.py b/src/transformers/models/ijepa/convert_ijepa_to_hf.py index 0546aae478e1..c2ed1611837c 100644 --- a/src/transformers/models/ijepa/convert_ijepa_to_hf.py +++ b/src/transformers/models/ijepa/convert_ijepa_to_hf.py @@ -18,6 +18,8 @@ """ import argparse +import gc +import re from pathlib import Path import requests @@ -35,91 +37,57 @@ logging.set_verbosity_info() logger = logging.get_logger(__name__) +# fmt: off +ORIGINAL_TO_CONVERTED_KEY_MAPPING = { + # Projection layer + position embeddings + r"pos_embed": r"embeddings.position_embeddings", + r"patch_embed.proj.weight": r"embeddings.patch_embeddings.projection.weight", + r"patch_embed.proj.bias": r"embeddings.patch_embeddings.projection.bias", + + # Encoder layers: Layernorms, Attention, Feedforward layers + r"blocks.(\d+).norm1.weight": r"encoder.layer.\1.layernorm_before.weight", + r"blocks.(\d+).norm1.bias": r"encoder.layer.\1.layernorm_before.bias", + r"blocks.(\d+).attn.proj.weight": r"encoder.layer.\1.attention.output.dense.weight", + r"blocks.(\d+).attn.proj.bias": r"encoder.layer.\1.attention.output.dense.bias", + r"blocks.(\d+).norm2.weight": r"encoder.layer.\1.layernorm_after.weight", + r"blocks.(\d+).norm2.bias": r"encoder.layer.\1.layernorm_after.bias", + r"blocks.(\d+).mlp.fc1.weight": r"encoder.layer.\1.intermediate.dense.weight", + r"blocks.(\d+).mlp.fc1.bias": r"encoder.layer.\1.intermediate.dense.bias", + r"blocks.(\d+).mlp.fc2.weight": r"encoder.layer.\1.output.dense.weight", + r"blocks.(\d+).mlp.fc2.bias": r"encoder.layer.\1.output.dense.bias", + + # Layernorm + pooler + r"norm.weight": r"layernorm.weight", + r"norm.bias": r"layernorm.bias", +} +# fmt: on + + +def convert_old_keys_to_new_keys(state_dict_keys: dict = None): + """ + Converts old keys to new keys using the mapping and dynamically removes the 'ijepa.' prefix if necessary. -# here we list all keys to be renamed (original name on the left, our name on the right) -def create_rename_keys(config): - rename_keys = [] - - # projection layer + position embeddings - rename_keys.append(("pos_embed", "ijepa.embeddings.position_embeddings")) - rename_keys.append(("patch_embed.proj.weight", "ijepa.embeddings.patch_embeddings.projection.weight")) - rename_keys.append(("patch_embed.proj.bias", "ijepa.embeddings.patch_embeddings.projection.bias")) + Args: + state_dict_keys (dict): The keys from the state_dict to convert. - for i in range(config.num_hidden_layers): - # encoder layers: output projection, 2 feedforward neural networks and 2 layernorms - rename_keys.append( - ( - f"blocks.{i}.norm1.weight", - f"ijepa.encoder.layer.{i}.layernorm_before.weight", - ) - ) - rename_keys.append( - ( - f"blocks.{i}.norm1.bias", - f"ijepa.encoder.layer.{i}.layernorm_before.bias", - ) - ) - rename_keys.append( - ( - f"blocks.{i}.attn.proj.weight", - f"ijepa.encoder.layer.{i}.attention.output.dense.weight", - ) - ) - rename_keys.append( - ( - f"blocks.{i}.attn.proj.bias", - f"ijepa.encoder.layer.{i}.attention.output.dense.bias", - ) - ) - rename_keys.append( - ( - f"blocks.{i}.norm2.weight", - f"ijepa.encoder.layer.{i}.layernorm_after.weight", - ) - ) - rename_keys.append( - ( - f"blocks.{i}.norm2.bias", - f"ijepa.encoder.layer.{i}.layernorm_after.bias", - ) - ) - rename_keys.append( - ( - f"blocks.{i}.mlp.fc1.weight", - f"ijepa.encoder.layer.{i}.intermediate.dense.weight", - ) - ) - rename_keys.append( - ( - f"blocks.{i}.mlp.fc1.bias", - f"ijepa.encoder.layer.{i}.intermediate.dense.bias", - ) - ) - rename_keys.append( - ( - f"blocks.{i}.mlp.fc2.weight", - f"ijepa.encoder.layer.{i}.output.dense.weight", - ) - ) - rename_keys.append( - ( - f"blocks.{i}.mlp.fc2.bias", - f"ijepa.encoder.layer.{i}.output.dense.bias", - ) - ) + Returns: + dict: A mapping from old keys to new keys. + """ + output_dict = {} + if state_dict_keys is not None: + old_text = "\n".join(state_dict_keys) + new_text = old_text - # layernorm + pooler - rename_keys.extend( - [ - ("norm.weight", "layernorm.weight"), - ("norm.bias", "layernorm.bias"), - ] - ) + # Apply regex-based mapping + for pattern, replacement in ORIGINAL_TO_CONVERTED_KEY_MAPPING.items(): + if replacement is None: + new_text = re.sub(pattern, "", new_text) # Skip the key + continue + new_text = re.sub(pattern, replacement, new_text) - # if just the base model, we should remove "ijepa" from all keys that start with "ijepa" - rename_keys = [(pair[0], pair[1][6:]) if pair[1].startswith("ijepa") else pair for pair in rename_keys] + output_dict = dict(zip(old_text.split("\n"), new_text.split("\n"))) - return rename_keys + return output_dict # we split up the matrix of each encoder layer into queries, keys and values @@ -184,7 +152,7 @@ def get_ijepa_config(model_name): @torch.no_grad() -def convert_ijepa_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub, verify_logits): +def write_model(model_name, output_dir, safe_serialization, push_to_hub, verify_logits): """ Copy/paste/tweak model's weights to our IJEPA structure. """ @@ -206,10 +174,9 @@ def convert_ijepa_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub, # Rename keys state_dict = original_state_dict.copy() - remove_classification_head_(state_dict) - rename_keys = create_rename_keys(config) - for src, dest in rename_keys: - rename_key(state_dict, src, dest) + new_keys = convert_old_keys_to_new_keys(state_dict.keys()) + for old_key, new_key in new_keys.items(): + rename_key(state_dict, old_key, new_key) read_in_q_k_v(state_dict, config) # load HuggingFace model @@ -234,10 +201,10 @@ def convert_ijepa_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub, atol=1e-4, ) - if pytorch_dump_folder_path is not None: - Path(pytorch_dump_folder_path).mkdir(exist_ok=True) - print(f"Saving model {model_name} to {pytorch_dump_folder_path}") - model.save_pretrained(pytorch_dump_folder_path) + if output_dir: + Path(output_dir).mkdir(exist_ok=True) + print(f"Saving model {model_name} to {output_dir}") + model.save_pretrained(output_dir, safe_serialization=safe_serialization) if push_to_hub: model_name_to_hf_name = { @@ -249,8 +216,15 @@ def convert_ijepa_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub, name = model_name_to_hf_name[model_name] model.push_to_hub(f"jmtzt/{name}", use_temp_dir=True) + if output_dir: + del model, state_dict + gc.collect() + print("Reloading the model to check if it's saved correctly.") + IJepaModel.from_pretrained(output_dir, device_map="auto") + print("Model reloaded successfully.") -if __name__ == "__main__": + +def main(): parser = argparse.ArgumentParser() # Required parameters parser.add_argument( @@ -266,11 +240,14 @@ def convert_ijepa_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub, help="Name of the model you'd like to convert.", ) parser.add_argument( - "--pytorch_dump_folder_path", + "--output_dir", default=None, type=str, help="Path to the output PyTorch model directory.", ) + parser.add_argument( + "--safe_serialization", default=True, type=bool, help="Whether or not to save using `safetensors`." + ) parser.add_argument( "--push_to_hub", action="store_true", @@ -282,4 +259,8 @@ def convert_ijepa_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub, parser.set_defaults() args = parser.parse_args() - convert_ijepa_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.push_to_hub, args.verify_logits) + write_model(args.model_name, args.output_dir, args.safe_serialization, args.push_to_hub, args.verify_logits) + + +if __name__ == "__main__": + main() diff --git a/src/transformers/models/ijepa/modeling_ijepa.py b/src/transformers/models/ijepa/modeling_ijepa.py index 0a93255c34ef..0d83949309c0 100644 --- a/src/transformers/models/ijepa/modeling_ijepa.py +++ b/src/transformers/models/ijepa/modeling_ijepa.py @@ -1,79 +1,58 @@ -# coding=utf-8 -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""PyTorch I-JEPA model.""" - +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/ijepa/modular_ijepa.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_ijepa.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 import collections.abc import math from typing import Dict, List, Optional, Set, Tuple, Union import torch -import torch.utils.checkpoint -from torch import nn +import torch.nn as nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN -from ...modeling_outputs import ( - BaseModelOutput, - BaseModelOutputWithPooling, - ImageClassifierOutput, -) +from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import ( - find_pruneable_heads_and_indices, - prune_linear_layer, -) -from ...utils import ( - add_code_sample_docstrings, - add_start_docstrings, - add_start_docstrings_to_model_forward, - logging, -) +from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer +from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging from .configuration_ijepa import IJepaConfig logger = logging.get_logger(__name__) -# General docstring -_CONFIG_FOR_DOC = "IJepaConfig" - # Base docstring -_CHECKPOINT_FOR_DOC = "facebook/I-JEPA" -_EXPECTED_OUTPUT_SHAPE = [1, 197, 768] +_CHECKPOINT_FOR_DOC = "google/ijepa-base-patch16-224-in21k" -# Image classification docstring -_IMAGE_CLASS_CHECKPOINT = "jmtzt/ijepa_huge_patch14_1k" -_IMAGE_CLASS_EXPECTED_OUTPUT = "Egyptian cat" +# General docstring +_CONFIG_FOR_DOC = "IJepaConfig" class IJepaEmbeddings(nn.Module): """ - Construct the position and patch embeddings. Optionally, also the mask token. + Construct the CLS token, position and patch embeddings. Optionally, also the mask token. """ def __init__(self, config: IJepaConfig, use_mask_token: bool = False) -> None: super().__init__() - self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) if use_mask_token else None self.patch_embeddings = IJepaPatchEmbeddings(config) num_patches = self.patch_embeddings.num_patches self.position_embeddings = nn.Parameter(torch.randn(1, num_patches, config.hidden_size)) self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.patch_size = config.patch_size self.config = config def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: """ + This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution + images. This method is also adapted to support torch.jit tracing. + + Adapted from: + - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and + - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211 + This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution images. @@ -116,9 +95,10 @@ def forward( self, pixel_values: torch.Tensor, bool_masked_pos: Optional[torch.BoolTensor] = None, + interpolate_pos_encoding: bool = False, ) -> torch.Tensor: batch_size, _, height, width = pixel_values.shape - embeddings = self.patch_embeddings(pixel_values) + embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding) if bool_masked_pos is not None: seq_length = embeddings.shape[1] @@ -128,7 +108,10 @@ def forward( embeddings = embeddings * (1.0 - mask) + mask_tokens * mask # add positional encoding to each token - embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width) + if interpolate_pos_encoding: + embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width) + else: + embeddings = embeddings + self.position_embeddings embeddings = self.dropout(embeddings) @@ -157,23 +140,23 @@ def __init__(self, config): self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size) - def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: + def forward(self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False) -> torch.Tensor: batch_size, num_channels, height, width = pixel_values.shape if num_channels != self.num_channels: raise ValueError( "Make sure that the channel dimension of the pixel values match with the one set in the configuration." f" Expected {self.num_channels} but got {num_channels}." ) - if height != self.image_size[0] or width != self.image_size[1]: - raise ValueError( - f"Input image size ({height}*{width}) doesn't match model" - f" ({self.image_size[0]}*{self.image_size[1]})." - ) + if not interpolate_pos_encoding: + if height != self.image_size[0] or width != self.image_size[1]: + raise ValueError( + f"Input image size ({height}*{width}) doesn't match model" + f" ({self.image_size[0]}*{self.image_size[1]})." + ) embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2) return embeddings -# Copied from transformers.models.vit.modeling_vit.ViTSelfAttention with ViT->IJepa class IJepaSelfAttention(nn.Module): def __init__(self, config: IJepaConfig) -> None: super().__init__() @@ -234,15 +217,30 @@ def forward( return outputs -# Copied from transformers.models.vit.modeling_vit.ViTSdpaSelfAttention with ViT->IJepa class IJepaSdpaSelfAttention(IJepaSelfAttention): def __init__(self, config: IJepaConfig) -> None: super().__init__(config) self.attention_probs_dropout_prob = config.attention_probs_dropout_prob def forward( - self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False + self, + hidden_states: torch.FloatTensor, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: + if output_attentions or head_mask is not None: + logger.warning_once( + "`IJepaSdpaAttention` is used but `torch.nn.functional.scaled_dot_product_attention` does not support " + "`output_attentions=True` or `head_mask`. Falling back to the manual attention implementation, but " + "specifying the manual implementation will be required from Transformers version v5.0.0 onwards. " + 'This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states=hidden_states, + head_mask=head_mask, + output_attentions=output_attentions, + ) + mixed_query_layer = self.query(hidden_states) key_layer = self.transpose_for_scores(self.key(hidden_states)) @@ -266,7 +264,6 @@ def forward( return context_layer, None -# Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->IJepa class IJepaSelfOutput(nn.Module): """ The residual connection is defined in IJepaLayer instead of here (as is the case with other models), due to the @@ -285,7 +282,6 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to return hidden_states -# Copied from transformers.models.vit.modeling_vit.ViTAttention with ViT->IJepa class IJepaAttention(nn.Module): def __init__(self, config: IJepaConfig) -> None: super().__init__() @@ -325,14 +321,12 @@ def forward( return outputs -# Copied from transformers.models.vit.modeling_vit.ViTSdpaAttention with ViT->IJepa class IJepaSdpaAttention(IJepaAttention): def __init__(self, config: IJepaConfig) -> None: super().__init__(config) self.attention = IJepaSdpaSelfAttention(config) -# Copied from transformers.models.vit.modeling_vit.ViTIntermediate with ViT->IJepa class IJepaIntermediate(nn.Module): def __init__(self, config: IJepaConfig) -> None: super().__init__() @@ -349,7 +343,6 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return hidden_states -# Copied from transformers.models.vit.modeling_vit.ViTOutput with ViT->IJepa class IJepaOutput(nn.Module): def __init__(self, config: IJepaConfig) -> None: super().__init__() @@ -365,13 +358,12 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to return hidden_states -IJepa_ATTENTION_CLASSES = { +IJEPA_ATTENTION_CLASSES = { "eager": IJepaAttention, "sdpa": IJepaSdpaAttention, } -# Copied from transformers.models.vit.modeling_vit.ViTLayer with VIT->IJepa,ViT->IJepa class IJepaLayer(nn.Module): """This corresponds to the Block class in the timm implementation.""" @@ -379,7 +371,7 @@ def __init__(self, config: IJepaConfig) -> None: super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward self.seq_len_dim = 1 - self.attention = IJepa_ATTENTION_CLASSES[config._attn_implementation](config) + self.attention = IJEPA_ATTENTION_CLASSES[config._attn_implementation](config) self.intermediate = IJepaIntermediate(config) self.output = IJepaOutput(config) self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) @@ -414,7 +406,6 @@ def forward( return outputs -# Copied from transformers.models.vit.modeling_vit.ViTEncoder with ViT->IJepa class IJepaEncoder(nn.Module): def __init__(self, config: IJepaConfig) -> None: super().__init__() @@ -500,6 +491,9 @@ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> No ).to(module.position_embeddings.dtype) +_EXPECTED_OUTPUT_SHAPE = [1, 197, 768] + + IJEPA_START_DOCSTRING = r""" This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and @@ -509,22 +503,14 @@ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> No config ([`IJepaConfig`]): Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. - - add_pooling_layer (bool, *optional*, defaults to `True`): Whether to include a pooling layer in the model. If set to `True`, the model will include a pooling layer that can be used to extract a pooled output representation of the hidden states. If set to `False`, the pooling layer will be omitted. - - use_mask_token (bool, *optional*, defaults to `False`): Whether to use a mask token in the embeddings layer. If set to `True`, a special token will be used to mask certain inputs during training. If set to `False`, no mask token will be used, and the embeddings layer will function without masking capability. - """ IJEPA_INPUTS_DOCSTRING = r""" Args: pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): - Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`ViTImageProcessor.__call__`] + Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`IJepaImageProcessor.__call__`] for details. - bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*): - Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). - head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: @@ -534,28 +520,26 @@ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> No output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. - output_hidden_states (`bool`, *optional*): Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more detail. - + interpolate_pos_encoding (`bool`, *optional*): + Whether to interpolate the pre-trained position encodings. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. """ @add_start_docstrings( - "The bare I-JEPA Model transformer outputting raw hidden-states without any specific head on top.", + "The bare IJepa Model transformer outputting raw hidden-states without any specific head on top.", IJEPA_START_DOCSTRING, ) class IJepaModel(IJepaPreTrainedModel): def __init__(self, config: IJepaConfig, add_pooling_layer: bool = True, use_mask_token: bool = False): super().__init__(config) self.config = config - self.embeddings = IJepaEmbeddings(config, use_mask_token=use_mask_token) self.encoder = IJepaEncoder(config) - self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.pooler = IJepaPooler(config) if add_pooling_layer else None @@ -588,8 +572,13 @@ def forward( head_mask: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple, BaseModelOutputWithPooling]: + r""" + bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*): + Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). + """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states @@ -612,8 +601,7 @@ def forward( pixel_values = pixel_values.to(expected_dtype) embedding_output = self.embeddings( - pixel_values, - bool_masked_pos=bool_masked_pos, + pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding ) encoder_outputs = self.encoder( @@ -639,7 +627,6 @@ def forward( ) -# Copied from transformers.models.vit.modeling_vit.ViTPooler with ViT->IJepa class IJepaPooler(nn.Module): def __init__(self, config: IJepaConfig): super().__init__() @@ -655,54 +642,25 @@ def forward(self, hidden_states): return pooled_output -IJEPA_FOR_IMAGE_CLASSIFICATION_START_DOCSTRING = r""" - This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it - as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and - behavior. - - Parameters: - config ([`IJepaConfig`]): Model configuration class with all the parameters of the model. - Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. -""" - -IJEPA_FOR_IMAGE_CLASSIFICATION_INPUTS_DOCSTRING = r""" - Args: - pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): - Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`ViTImageProcessor.__call__`] - for details. - - head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): - Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels for computing the image classification/regression loss. Indices should be in `[0, ..., - config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If - `config.num_labels > 1` a classification loss is computed (Cross-Entropy). - - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. -""" +# Image classification docstring +_IMAGE_CLASS_CHECKPOINT = "google/ijepa-base-patch16-224" +_IMAGE_CLASS_EXPECTED_OUTPUT = "Egyptian cat" @add_start_docstrings( """ - I-JEPA Model transformer with an image classification head on top (a linear layer on top of the final hidden state of - the average pooling of the output) e.g. for ImageNet. + IJepa Model transformer with an image classification head on top (a linear layer on top of the final hidden state of + the [CLS] token) e.g. for ImageNet. + + + Note that it's possible to fine-tune IJepa on higher resolution images than the ones it has been trained on, by + setting `interpolate_pos_encoding` to `True` in the forward of the model. This will interpolate the pre-trained + position embeddings to the higher resolution. + + """, - IJEPA_FOR_IMAGE_CLASSIFICATION_START_DOCSTRING, + IJEPA_START_DOCSTRING, ) class IJepaForImageClassification(IJepaPreTrainedModel): def __init__(self, config: IJepaConfig) -> None: @@ -713,11 +671,12 @@ def __init__(self, config: IJepaConfig) -> None: # Classifier head self.classifier = nn.Linear(config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity() + self.vit = IJepaModel(config) # Initialize weights and apply final processing self.post_init() - @add_start_docstrings_to_model_forward(IJEPA_FOR_IMAGE_CLASSIFICATION_INPUTS_DOCSTRING) + @add_start_docstrings_to_model_forward(IJEPA_INPUTS_DOCSTRING) @add_code_sample_docstrings( checkpoint=_IMAGE_CLASS_CHECKPOINT, output_type=ImageClassifierOutput, @@ -731,8 +690,15 @@ def forward( labels: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[tuple, ImageClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the image classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ return_dict = return_dict if return_dict is not None else self.config.use_return_dict outputs = self.ijepa( @@ -740,6 +706,7 @@ def forward( head_mask=head_mask, output_attentions=output_attentions, output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, return_dict=return_dict, ) @@ -782,3 +749,6 @@ def forward( hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) + + +__all__ = ["IJepaPreTrainedModel", "IJepaModel", "IJepaForImageClassification"] diff --git a/src/transformers/models/ijepa/modular_ijepa.py b/src/transformers/models/ijepa/modular_ijepa.py new file mode 100644 index 000000000000..1870524d38a4 --- /dev/null +++ b/src/transformers/models/ijepa/modular_ijepa.py @@ -0,0 +1,205 @@ +import math +from typing import Optional, Union + +import torch +import torch.nn as nn + +from transformers.models.ijepa.configuration_ijepa import IJepaConfig + +from ...modeling_utils import PreTrainedModel +from ..vit.modeling_vit import ( + ViTAttention, + ViTEmbeddings, + ViTEncoder, + ViTForImageClassification, + ViTIntermediate, + ViTLayer, + ViTModel, + ViTPatchEmbeddings, + ViTPooler, + ViTSdpaAttention, + ViTSdpaSelfAttention, + ViTSelfAttention, + ViTSelfOutput, +) + + +class IJepaEmbeddings(ViTEmbeddings): + def __init__(self, config: IJepaConfig, use_mask_token: bool = False) -> None: + super().__init__(config, use_mask_token) + # Remove cls_token from IJepaEmbeddings, as it is not used in the model + del self.cls_token + num_patches = self.patch_embeddings.num_patches + self.position_embeddings = nn.Parameter(torch.randn(1, num_patches, config.hidden_size)) + + def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: + """ + This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher + resolution images. + + Source: + https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174 + """ + + num_patches = embeddings.shape[1] - 1 + num_positions = self.position_embeddings.shape[1] - 1 + if num_patches == num_positions and height == width: + return self.position_embeddings + class_pos_embed = self.position_embeddings[:, 0] + patch_pos_embed = self.position_embeddings[:, 1:] + dim = embeddings.shape[-1] + h0 = height // self.config.patch_size + w0 = width // self.config.patch_size + # we add a small number to avoid floating point error in the interpolation + # see discussion at https://github.com/facebookresearch/dino/issues/8 + h0, w0 = h0 + 0.1, w0 + 0.1 + patch_pos_embed = patch_pos_embed.reshape( + 1, + int(math.sqrt(num_positions)), + int(math.sqrt(num_positions)), + dim, + ) + patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed, + scale_factor=( + h0 / math.sqrt(num_positions), + w0 / math.sqrt(num_positions), + ), + mode="bicubic", + align_corners=False, + ) + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) + + def forward( + self, + pixel_values: torch.Tensor, + bool_masked_pos: Optional[torch.BoolTensor] = None, + interpolate_pos_encoding: bool = False, + ) -> torch.Tensor: + batch_size, _, height, width = pixel_values.shape + embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding) + + if bool_masked_pos is not None: + seq_length = embeddings.shape[1] + mask_tokens = self.mask_token.expand(batch_size, seq_length, -1) + # replace the masked visual tokens by mask_tokens + mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens) + embeddings = embeddings * (1.0 - mask) + mask_tokens * mask + + # add positional encoding to each token + if interpolate_pos_encoding: + embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width) + else: + embeddings = embeddings + self.position_embeddings + + embeddings = self.dropout(embeddings) + + return embeddings + + +class IJepaPatchEmbeddings(ViTPatchEmbeddings): + pass + + +class IJepaSelfAttention(ViTSelfAttention): + pass + + +class IJepaSdpaSelfAttention(ViTSdpaSelfAttention): + pass + + +class IJepaSelfOutput(ViTSelfOutput): + pass + + +class IJepaAttention(ViTAttention): + pass + + +class IJepaSdpaAttention(ViTSdpaAttention): + pass + + +class IJepaIntermediate(ViTIntermediate): + pass + + +IJepa_ATTENTION_CLASSES = { + "eager": IJepaAttention, + "sdpa": IJepaSdpaAttention, +} + + +class IJepaLayer(ViTLayer): + pass + + +class IJepaEncoder(ViTEncoder): + pass + + +class IJepaPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = IJepaConfig + base_model_prefix = "ijepa" + main_input_name = "pixel_values" + supports_gradient_checkpointing = True + _no_split_modules = ["IJepaEmbeddings", "IJepaLayer"] + _supports_sdpa = True + + def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None: + """Initialize the weights""" + if isinstance(module, (nn.Linear, nn.Conv2d)): + # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid + # `trunc_normal_cpu` not implemented in `half` issues + module.weight.data = nn.init.trunc_normal_( + module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range + ).to(module.weight.dtype) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, IJepaEmbeddings): + module.position_embeddings.data = nn.init.trunc_normal_( + module.position_embeddings.data.to(torch.float32), + mean=0.0, + std=self.config.initializer_range, + ).to(module.position_embeddings.dtype) + + +class IJepaModel(IJepaPreTrainedModel, ViTModel): + def __init__(self, config: IJepaConfig, add_pooling_layer: bool = True, use_mask_token: bool = False): + super().__init__(config) + self.config = config + self.embeddings = IJepaEmbeddings(config, use_mask_token=use_mask_token) + self.encoder = IJepaEncoder(config) + self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.pooler = IJepaPooler(config) if add_pooling_layer else None + # Initialize weights and apply final processing + self.post_init() + + +class IJepaPooler(ViTPooler): + pass + + +class IJepaForImageClassification(IJepaPreTrainedModel, ViTForImageClassification): + def __init__(self, config: IJepaConfig): + super().__init__(config) + self.vit = IJepaModel(config) + self.post_init() + + +__all__ = [ + "IJepaPreTrainedModel", + "IJepaModel", + "IJepaForImageClassification", +] From b13a24e9a332d6c6431983b036d55423d806b49b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Marcelo=20Tozato?= Date: Sat, 16 Nov 2024 19:51:54 +0100 Subject: [PATCH 25/44] add IJepaModel to objects to ignore in docstring checks --- utils/check_docstrings.py | 1 + 1 file changed, 1 insertion(+) diff --git a/utils/check_docstrings.py b/utils/check_docstrings.py index 0be960f4a33e..a2ea05edce80 100644 --- a/utils/check_docstrings.py +++ b/utils/check_docstrings.py @@ -331,6 +331,7 @@ "IBertModel", "IdeficsConfig", "IdeficsProcessor", + "IJepaModel", "ImageClassificationPipeline", "ImageFeatureExtractionPipeline", "ImageGPTConfig", From 2b154ce7120b2fdac6dff3f8fc90f42d27bb4b23 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Marcelo=20Tozato?= Date: Sat, 16 Nov 2024 20:00:34 +0100 Subject: [PATCH 26/44] [run-slow] ijepa From 3f0c027051d7f6685dcddc11b69df36fc1f2fe3e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Marcelo=20Tozato?= Date: Mon, 18 Nov 2024 09:52:40 +0100 Subject: [PATCH 27/44] fix formatting issues --- src/transformers/__init__.py | 666 ++++++++++++++++++++++------------- 1 file changed, 419 insertions(+), 247 deletions(-) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index e62093dd7cfb..ba184174f1d2 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -540,10 +540,7 @@ "LlavaNextVideoConfig", "LlavaNextVideoProcessor", ], - "models.llava_onevision": [ - "LlavaOnevisionConfig", - "LlavaOnevisionProcessor", - ], + "models.llava_onevision": ["LlavaOnevisionConfig", "LlavaOnevisionProcessor"], "models.longformer": [ "LongformerConfig", "LongformerTokenizer", @@ -1196,13 +1193,7 @@ _import_structure["models.deprecated.efficientformer"].append("EfficientFormerImageProcessor") _import_structure["models.deprecated.tvlt"].append("TvltImageProcessor") _import_structure["models.deprecated.vit_hybrid"].extend(["ViTHybridImageProcessor"]) - _import_structure["models.detr"].extend( - [ - "DetrFeatureExtractor", - "DetrImageProcessor", - "DetrImageProcessorFast", - ] - ) + _import_structure["models.detr"].extend(["DetrFeatureExtractor", "DetrImageProcessor", "DetrImageProcessorFast"]) _import_structure["models.donut"].extend(["DonutFeatureExtractor", "DonutImageProcessor"]) _import_structure["models.dpt"].extend(["DPTFeatureExtractor", "DPTImageProcessor"]) _import_structure["models.efficientnet"].append("EfficientNetImageProcessor") @@ -2681,12 +2672,7 @@ ] ) _import_structure["models.marian"].extend( - [ - "MarianForCausalLM", - "MarianModel", - "MarianMTModel", - "MarianPreTrainedModel", - ] + ["MarianForCausalLM", "MarianModel", "MarianMTModel", "MarianPreTrainedModel"] ) _import_structure["models.markuplm"].extend( [ @@ -3888,10 +3874,7 @@ "TFTopPLogitsWarper", ] ) - _import_structure["keras_callbacks"] = [ - "KerasMetricCallback", - "PushToHubCallback", - ] + _import_structure["keras_callbacks"] = ["KerasMetricCallback", "PushToHubCallback"] _import_structure["modeling_tf_outputs"] = [] _import_structure["modeling_tf_utils"] = [ "TFPreTrainedModel", @@ -4279,19 +4262,10 @@ ) _import_structure["models.marian"].extend(["TFMarianModel", "TFMarianMTModel", "TFMarianPreTrainedModel"]) _import_structure["models.mbart"].extend( - [ - "TFMBartForConditionalGeneration", - "TFMBartModel", - "TFMBartPreTrainedModel", - ] + ["TFMBartForConditionalGeneration", "TFMBartModel", "TFMBartPreTrainedModel"] ) _import_structure["models.mistral"].extend( - [ - "TFMistralForCausalLM", - "TFMistralForSequenceClassification", - "TFMistralModel", - "TFMistralPreTrainedModel", - ] + ["TFMistralForCausalLM", "TFMistralForSequenceClassification", "TFMistralModel", "TFMistralPreTrainedModel"] ) _import_structure["models.mobilebert"].extend( [ @@ -4589,7 +4563,9 @@ if not is_torchaudio_available(): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: - from .utils import dummy_torchaudio_objects + from .utils import ( + dummy_torchaudio_objects, + ) _import_structure["utils.dummy_torchaudio_objects"] = [ name for name in dir(dummy_torchaudio_objects) if not name.startswith("_") @@ -4789,11 +4765,7 @@ _import_structure["models.encoder_decoder"].append("FlaxEncoderDecoderModel") _import_structure["models.gpt2"].extend(["FlaxGPT2LMHeadModel", "FlaxGPT2Model", "FlaxGPT2PreTrainedModel"]) _import_structure["models.gpt_neo"].extend( - [ - "FlaxGPTNeoForCausalLM", - "FlaxGPTNeoModel", - "FlaxGPTNeoPreTrainedModel", - ] + ["FlaxGPTNeoForCausalLM", "FlaxGPTNeoModel", "FlaxGPTNeoPreTrainedModel"] ) _import_structure["models.gptj"].extend(["FlaxGPTJForCausalLM", "FlaxGPTJModel", "FlaxGPTJPreTrainedModel"]) _import_structure["models.llama"].extend(["FlaxLlamaForCausalLM", "FlaxLlamaModel", "FlaxLlamaPreTrainedModel"]) @@ -4828,13 +4800,7 @@ "FlaxMistralPreTrainedModel", ] ) - _import_structure["models.mt5"].extend( - [ - "FlaxMT5EncoderModel", - "FlaxMT5ForConditionalGeneration", - "FlaxMT5Model", - ] - ) + _import_structure["models.mt5"].extend(["FlaxMT5EncoderModel", "FlaxMT5ForConditionalGeneration", "FlaxMT5Model"]) _import_structure["models.opt"].extend( [ "FlaxOPTForCausalLM", @@ -4909,13 +4875,7 @@ ) _import_structure["models.vision_encoder_decoder"].append("FlaxVisionEncoderDecoderModel") _import_structure["models.vision_text_dual_encoder"].extend(["FlaxVisionTextDualEncoderModel"]) - _import_structure["models.vit"].extend( - [ - "FlaxViTForImageClassification", - "FlaxViTModel", - "FlaxViTPreTrainedModel", - ] - ) + _import_structure["models.vit"].extend(["FlaxViTForImageClassification", "FlaxViTModel", "FlaxViTPreTrainedModel"]) _import_structure["models.wav2vec2"].extend( [ "FlaxWav2Vec2ForCTC", @@ -5017,12 +4977,7 @@ from .feature_extraction_utils import BatchFeature, FeatureExtractionMixin # Generation - from .generation import ( - GenerationConfig, - TextIteratorStreamer, - TextStreamer, - WatermarkingConfig, - ) + from .generation import GenerationConfig, TextIteratorStreamer, TextStreamer, WatermarkingConfig from .hf_argparser import HfArgumentParser # Integrations @@ -5082,7 +5037,9 @@ AutoProcessor, AutoTokenizer, ) - from .models.autoformer import AutoformerConfig + from .models.autoformer import ( + AutoformerConfig, + ) from .models.bark import ( BarkCoarseConfig, BarkConfig, @@ -5106,10 +5063,18 @@ ) from .models.bertweet import BertweetTokenizer from .models.big_bird import BigBirdConfig - from .models.bigbird_pegasus import BigBirdPegasusConfig - from .models.biogpt import BioGptConfig, BioGptTokenizer + from .models.bigbird_pegasus import ( + BigBirdPegasusConfig, + ) + from .models.biogpt import ( + BioGptConfig, + BioGptTokenizer, + ) from .models.bit import BitConfig - from .models.blenderbot import BlenderbotConfig, BlenderbotTokenizer + from .models.blenderbot import ( + BlenderbotConfig, + BlenderbotTokenizer, + ) from .models.blenderbot_small import ( BlenderbotSmallConfig, BlenderbotSmallTokenizer, @@ -5133,10 +5098,18 @@ BridgeTowerTextConfig, BridgeTowerVisionConfig, ) - from .models.bros import BrosConfig, BrosProcessor + from .models.bros import ( + BrosConfig, + BrosProcessor, + ) from .models.byt5 import ByT5Tokenizer - from .models.camembert import CamembertConfig - from .models.canine import CanineConfig, CanineTokenizer + from .models.camembert import ( + CamembertConfig, + ) + from .models.canine import ( + CanineConfig, + CanineTokenizer, + ) from .models.chameleon import ( ChameleonConfig, ChameleonProcessor, @@ -5175,29 +5148,59 @@ ClvpProcessor, ClvpTokenizer, ) - from .models.codegen import CodeGenConfig, CodeGenTokenizer + from .models.codegen import ( + CodeGenConfig, + CodeGenTokenizer, + ) from .models.cohere import CohereConfig - from .models.conditional_detr import ConditionalDetrConfig - from .models.convbert import ConvBertConfig, ConvBertTokenizer + from .models.conditional_detr import ( + ConditionalDetrConfig, + ) + from .models.convbert import ( + ConvBertConfig, + ConvBertTokenizer, + ) from .models.convnext import ConvNextConfig - from .models.convnextv2 import ConvNextV2Config - from .models.cpmant import CpmAntConfig, CpmAntTokenizer - from .models.ctrl import CTRLConfig, CTRLTokenizer + from .models.convnextv2 import ( + ConvNextV2Config, + ) + from .models.cpmant import ( + CpmAntConfig, + CpmAntTokenizer, + ) + from .models.ctrl import ( + CTRLConfig, + CTRLTokenizer, + ) from .models.cvt import CvtConfig - from .models.dac import DacConfig, DacFeatureExtractor + from .models.dac import ( + DacConfig, + DacFeatureExtractor, + ) from .models.data2vec import ( Data2VecAudioConfig, Data2VecTextConfig, Data2VecVisionConfig, ) from .models.dbrx import DbrxConfig - from .models.deberta import DebertaConfig, DebertaTokenizer - from .models.deberta_v2 import DebertaV2Config - from .models.decision_transformer import DecisionTransformerConfig - from .models.deformable_detr import DeformableDetrConfig + from .models.deberta import ( + DebertaConfig, + DebertaTokenizer, + ) + from .models.deberta_v2 import ( + DebertaV2Config, + ) + from .models.decision_transformer import ( + DecisionTransformerConfig, + ) + from .models.deformable_detr import ( + DeformableDetrConfig, + ) from .models.deit import DeiTConfig from .models.deprecated.deta import DetaConfig - from .models.deprecated.efficientformer import EfficientFormerConfig + from .models.deprecated.efficientformer import ( + EfficientFormerConfig, + ) from .models.deprecated.ernie_m import ErnieMConfig from .models.deprecated.gptsan_japanese import ( GPTSanJapaneseConfig, @@ -5219,9 +5222,14 @@ from .models.deprecated.mmbt import MMBTConfig from .models.deprecated.nat import NatConfig from .models.deprecated.nezha import NezhaConfig - from .models.deprecated.open_llama import OpenLlamaConfig + from .models.deprecated.open_llama import ( + OpenLlamaConfig, + ) from .models.deprecated.qdqbert import QDQBertConfig - from .models.deprecated.realm import RealmConfig, RealmTokenizer + from .models.deprecated.realm import ( + RealmConfig, + RealmTokenizer, + ) from .models.deprecated.retribert import ( RetriBertConfig, RetriBertTokenizer, @@ -5246,14 +5254,24 @@ TvltProcessor, ) from .models.deprecated.van import VanConfig - from .models.deprecated.vit_hybrid import ViTHybridConfig - from .models.deprecated.xlm_prophetnet import XLMProphetNetConfig + from .models.deprecated.vit_hybrid import ( + ViTHybridConfig, + ) + from .models.deprecated.xlm_prophetnet import ( + XLMProphetNetConfig, + ) from .models.depth_anything import DepthAnythingConfig from .models.detr import DetrConfig from .models.dinat import DinatConfig from .models.dinov2 import Dinov2Config - from .models.distilbert import DistilBertConfig, DistilBertTokenizer - from .models.donut import DonutProcessor, DonutSwinConfig + from .models.distilbert import ( + DistilBertConfig, + DistilBertTokenizer, + ) + from .models.donut import ( + DonutProcessor, + DonutSwinConfig, + ) from .models.dpr import ( DPRConfig, DPRContextEncoderTokenizer, @@ -5262,9 +5280,17 @@ DPRReaderTokenizer, ) from .models.dpt import DPTConfig - from .models.efficientnet import EfficientNetConfig - from .models.electra import ElectraConfig, ElectraTokenizer - from .models.encodec import EncodecConfig, EncodecFeatureExtractor + from .models.efficientnet import ( + EfficientNetConfig, + ) + from .models.electra import ( + ElectraConfig, + ElectraTokenizer, + ) + from .models.encodec import ( + EncodecConfig, + EncodecFeatureExtractor, + ) from .models.encoder_decoder import EncoderDecoderConfig from .models.ernie import ErnieConfig from .models.esm import EsmConfig, EsmTokenizer @@ -5286,19 +5312,36 @@ ) from .models.fnet import FNetConfig from .models.focalnet import FocalNetConfig - from .models.fsmt import FSMTConfig, FSMTTokenizer - from .models.funnel import FunnelConfig, FunnelTokenizer + from .models.fsmt import ( + FSMTConfig, + FSMTTokenizer, + ) + from .models.funnel import ( + FunnelConfig, + FunnelTokenizer, + ) from .models.fuyu import FuyuConfig from .models.gemma import GemmaConfig from .models.gemma2 import Gemma2Config - from .models.git import GitConfig, GitProcessor, GitVisionConfig + from .models.git import ( + GitConfig, + GitProcessor, + GitVisionConfig, + ) from .models.glm import GlmConfig from .models.glpn import GLPNConfig - from .models.gpt2 import GPT2Config, GPT2Tokenizer - from .models.gpt_bigcode import GPTBigCodeConfig + from .models.gpt2 import ( + GPT2Config, + GPT2Tokenizer, + ) + from .models.gpt_bigcode import ( + GPTBigCodeConfig, + ) from .models.gpt_neo import GPTNeoConfig from .models.gpt_neox import GPTNeoXConfig - from .models.gpt_neox_japanese import GPTNeoXJapaneseConfig + from .models.gpt_neox_japanese import ( + GPTNeoXJapaneseConfig, + ) from .models.gptj import GPTJConfig from .models.granite import GraniteConfig from .models.granitemoe import GraniteMoeConfig @@ -5315,7 +5358,9 @@ from .models.hiera import HieraConfig from .models.hubert import HubertConfig from .models.ibert import IBertConfig - from .models.idefics import IdeficsConfig + from .models.idefics import ( + IdeficsConfig, + ) from .models.idefics2 import Idefics2Config from .models.idefics3 import Idefics3Config from .models.ijepa import IJepaConfig @@ -5335,8 +5380,14 @@ ) from .models.jamba import JambaConfig from .models.jetmoe import JetMoeConfig - from .models.kosmos2 import Kosmos2Config, Kosmos2Processor - from .models.layoutlm import LayoutLMConfig, LayoutLMTokenizer + from .models.kosmos2 import ( + Kosmos2Config, + Kosmos2Processor, + ) + from .models.layoutlm import ( + LayoutLMConfig, + LayoutLMTokenizer, + ) from .models.layoutlmv2 import ( LayoutLMv2Config, LayoutLMv2FeatureExtractor, @@ -5356,8 +5407,14 @@ from .models.levit import LevitConfig from .models.lilt import LiltConfig from .models.llama import LlamaConfig - from .models.llava import LlavaConfig, LlavaProcessor - from .models.llava_next import LlavaNextConfig, LlavaNextProcessor + from .models.llava import ( + LlavaConfig, + LlavaProcessor, + ) + from .models.llava_next import ( + LlavaNextConfig, + LlavaNextProcessor, + ) from .models.llava_next_video import ( LlavaNextVideoConfig, LlavaNextVideoProcessor, @@ -5366,10 +5423,19 @@ LlavaOnevisionConfig, LlavaOnevisionProcessor, ) - from .models.longformer import LongformerConfig, LongformerTokenizer + from .models.longformer import ( + LongformerConfig, + LongformerTokenizer, + ) from .models.longt5 import LongT5Config - from .models.luke import LukeConfig, LukeTokenizer - from .models.lxmert import LxmertConfig, LxmertTokenizer + from .models.luke import ( + LukeConfig, + LukeTokenizer, + ) + from .models.lxmert import ( + LxmertConfig, + LxmertTokenizer, + ) from .models.m2m_100 import M2M100Config from .models.mamba import MambaConfig from .models.mamba2 import Mamba2Config @@ -5380,26 +5446,62 @@ MarkupLMProcessor, MarkupLMTokenizer, ) - from .models.mask2former import Mask2FormerConfig - from .models.maskformer import MaskFormerConfig, MaskFormerSwinConfig + from .models.mask2former import ( + Mask2FormerConfig, + ) + from .models.maskformer import ( + MaskFormerConfig, + MaskFormerSwinConfig, + ) from .models.mbart import MBartConfig - from .models.megatron_bert import MegatronBertConfig - from .models.mgp_str import MgpstrConfig, MgpstrProcessor, MgpstrTokenizer - from .models.mimi import MimiConfig + from .models.megatron_bert import ( + MegatronBertConfig, + ) + from .models.mgp_str import ( + MgpstrConfig, + MgpstrProcessor, + MgpstrTokenizer, + ) + from .models.mimi import ( + MimiConfig, + ) from .models.mistral import MistralConfig from .models.mixtral import MixtralConfig - from .models.mllama import MllamaConfig, MllamaProcessor - from .models.mobilebert import MobileBertConfig, MobileBertTokenizer - from .models.mobilenet_v1 import MobileNetV1Config - from .models.mobilenet_v2 import MobileNetV2Config - from .models.mobilevit import MobileViTConfig - from .models.mobilevitv2 import MobileViTV2Config - from .models.moshi import MoshiConfig, MoshiDepthConfig - from .models.mpnet import MPNetConfig, MPNetTokenizer + from .models.mllama import ( + MllamaConfig, + MllamaProcessor, + ) + from .models.mobilebert import ( + MobileBertConfig, + MobileBertTokenizer, + ) + from .models.mobilenet_v1 import ( + MobileNetV1Config, + ) + from .models.mobilenet_v2 import ( + MobileNetV2Config, + ) + from .models.mobilevit import ( + MobileViTConfig, + ) + from .models.mobilevitv2 import ( + MobileViTV2Config, + ) + from .models.moshi import ( + MoshiConfig, + MoshiDepthConfig, + ) + from .models.mpnet import ( + MPNetConfig, + MPNetTokenizer, + ) from .models.mpt import MptConfig from .models.mra import MraConfig from .models.mt5 import MT5Config - from .models.musicgen import MusicgenConfig, MusicgenDecoderConfig + from .models.musicgen import ( + MusicgenConfig, + MusicgenDecoderConfig, + ) from .models.musicgen_melody import ( MusicgenMelodyConfig, MusicgenMelodyDecoderConfig, @@ -5409,12 +5511,23 @@ from .models.nemotron import NemotronConfig from .models.nllb_moe import NllbMoeConfig from .models.nougat import NougatProcessor - from .models.nystromformer import NystromformerConfig + from .models.nystromformer import ( + NystromformerConfig, + ) from .models.olmo import OlmoConfig from .models.olmoe import OlmoeConfig - from .models.omdet_turbo import OmDetTurboConfig, OmDetTurboProcessor - from .models.oneformer import OneFormerConfig, OneFormerProcessor - from .models.openai import OpenAIGPTConfig, OpenAIGPTTokenizer + from .models.omdet_turbo import ( + OmDetTurboConfig, + OmDetTurboProcessor, + ) + from .models.oneformer import ( + OneFormerConfig, + OneFormerProcessor, + ) + from .models.openai import ( + OpenAIGPTConfig, + OpenAIGPTTokenizer, + ) from .models.opt import OPTConfig from .models.owlv2 import ( Owlv2Config, @@ -5428,13 +5541,27 @@ OwlViTTextConfig, OwlViTVisionConfig, ) - from .models.paligemma import PaliGemmaConfig - from .models.patchtsmixer import PatchTSMixerConfig + from .models.paligemma import ( + PaliGemmaConfig, + ) + from .models.patchtsmixer import ( + PatchTSMixerConfig, + ) from .models.patchtst import PatchTSTConfig - from .models.pegasus import PegasusConfig, PegasusTokenizer - from .models.pegasus_x import PegasusXConfig - from .models.perceiver import PerceiverConfig, PerceiverTokenizer - from .models.persimmon import PersimmonConfig + from .models.pegasus import ( + PegasusConfig, + PegasusTokenizer, + ) + from .models.pegasus_x import ( + PegasusXConfig, + ) + from .models.perceiver import ( + PerceiverConfig, + PerceiverTokenizer, + ) + from .models.persimmon import ( + PersimmonConfig, + ) from .models.phi import PhiConfig from .models.phi3 import Phi3Config from .models.phimoe import PhimoeConfig @@ -5445,11 +5572,21 @@ Pix2StructTextConfig, Pix2StructVisionConfig, ) - from .models.pixtral import PixtralProcessor, PixtralVisionConfig + from .models.pixtral import ( + PixtralProcessor, + PixtralVisionConfig, + ) from .models.plbart import PLBartConfig - from .models.poolformer import PoolFormerConfig - from .models.pop2piano import Pop2PianoConfig - from .models.prophetnet import ProphetNetConfig, ProphetNetTokenizer + from .models.poolformer import ( + PoolFormerConfig, + ) + from .models.pop2piano import ( + Pop2PianoConfig, + ) + from .models.prophetnet import ( + ProphetNetConfig, + ProphetNetTokenizer, + ) from .models.pvt import PvtConfig from .models.pvt_v2 import PvtV2Config from .models.qwen2 import Qwen2Config, Qwen2Tokenizer @@ -5459,18 +5596,35 @@ Qwen2AudioProcessor, ) from .models.qwen2_moe import Qwen2MoeConfig - from .models.qwen2_vl import Qwen2VLConfig, Qwen2VLProcessor + from .models.qwen2_vl import ( + Qwen2VLConfig, + Qwen2VLProcessor, + ) from .models.rag import RagConfig, RagRetriever, RagTokenizer from .models.recurrent_gemma import RecurrentGemmaConfig from .models.reformer import ReformerConfig from .models.regnet import RegNetConfig from .models.rembert import RemBertConfig from .models.resnet import ResNetConfig - from .models.roberta import RobertaConfig, RobertaTokenizer - from .models.roberta_prelayernorm import RobertaPreLayerNormConfig - from .models.roc_bert import RoCBertConfig, RoCBertTokenizer - from .models.roformer import RoFormerConfig, RoFormerTokenizer - from .models.rt_detr import RTDetrConfig, RTDetrResNetConfig + from .models.roberta import ( + RobertaConfig, + RobertaTokenizer, + ) + from .models.roberta_prelayernorm import ( + RobertaPreLayerNormConfig, + ) + from .models.roc_bert import ( + RoCBertConfig, + RoCBertTokenizer, + ) + from .models.roformer import ( + RoFormerConfig, + RoFormerTokenizer, + ) + from .models.rt_detr import ( + RTDetrConfig, + RTDetrResNetConfig, + ) from .models.rwkv import RwkvConfig from .models.sam import ( SamConfig, @@ -5484,7 +5638,9 @@ SeamlessM4TFeatureExtractor, SeamlessM4TProcessor, ) - from .models.seamless_m4t_v2 import SeamlessM4Tv2Config + from .models.seamless_m4t_v2 import ( + SeamlessM4Tv2Config, + ) from .models.segformer import SegformerConfig from .models.seggpt import SegGptConfig from .models.sew import SEWConfig @@ -5507,29 +5663,61 @@ SpeechT5HifiGanConfig, SpeechT5Processor, ) - from .models.splinter import SplinterConfig, SplinterTokenizer - from .models.squeezebert import SqueezeBertConfig, SqueezeBertTokenizer + from .models.splinter import ( + SplinterConfig, + SplinterTokenizer, + ) + from .models.squeezebert import ( + SqueezeBertConfig, + SqueezeBertTokenizer, + ) from .models.stablelm import StableLmConfig from .models.starcoder2 import Starcoder2Config from .models.superpoint import SuperPointConfig - from .models.swiftformer import SwiftFormerConfig + from .models.swiftformer import ( + SwiftFormerConfig, + ) from .models.swin import SwinConfig from .models.swin2sr import Swin2SRConfig from .models.swinv2 import Swinv2Config - from .models.switch_transformers import SwitchTransformersConfig + from .models.switch_transformers import ( + SwitchTransformersConfig, + ) from .models.t5 import T5Config - from .models.table_transformer import TableTransformerConfig - from .models.tapas import TapasConfig, TapasTokenizer - from .models.time_series_transformer import TimeSeriesTransformerConfig - from .models.timesformer import TimesformerConfig + from .models.table_transformer import ( + TableTransformerConfig, + ) + from .models.tapas import ( + TapasConfig, + TapasTokenizer, + ) + from .models.time_series_transformer import ( + TimeSeriesTransformerConfig, + ) + from .models.timesformer import ( + TimesformerConfig, + ) from .models.timm_backbone import TimmBackboneConfig - from .models.trocr import TrOCRConfig, TrOCRProcessor - from .models.tvp import TvpConfig, TvpProcessor + from .models.trocr import ( + TrOCRConfig, + TrOCRProcessor, + ) + from .models.tvp import ( + TvpConfig, + TvpProcessor, + ) from .models.udop import UdopConfig, UdopProcessor from .models.umt5 import UMT5Config - from .models.unispeech import UniSpeechConfig - from .models.unispeech_sat import UniSpeechSatConfig - from .models.univnet import UnivNetConfig, UnivNetFeatureExtractor + from .models.unispeech import ( + UniSpeechConfig, + ) + from .models.unispeech_sat import ( + UniSpeechSatConfig, + ) + from .models.univnet import ( + UnivNetConfig, + UnivNetFeatureExtractor, + ) from .models.upernet import UperNetConfig from .models.video_llava import VideoLlavaConfig from .models.videomae import VideoMAEConfig @@ -5539,19 +5727,26 @@ ViltImageProcessor, ViltProcessor, ) - from .models.vipllava import VipLlavaConfig + from .models.vipllava import ( + VipLlavaConfig, + ) from .models.vision_encoder_decoder import VisionEncoderDecoderConfig from .models.vision_text_dual_encoder import ( VisionTextDualEncoderConfig, VisionTextDualEncoderProcessor, ) - from .models.visual_bert import VisualBertConfig + from .models.visual_bert import ( + VisualBertConfig, + ) from .models.vit import ViTConfig from .models.vit_mae import ViTMAEConfig from .models.vit_msn import ViTMSNConfig from .models.vitdet import VitDetConfig from .models.vitmatte import VitMatteConfig - from .models.vits import VitsConfig, VitsTokenizer + from .models.vits import ( + VitsConfig, + VitsTokenizer, + ) from .models.vivit import VivitConfig from .models.wav2vec2 import ( Wav2Vec2Config, @@ -5560,8 +5755,13 @@ Wav2Vec2Processor, Wav2Vec2Tokenizer, ) - from .models.wav2vec2_bert import Wav2Vec2BertConfig, Wav2Vec2BertProcessor - from .models.wav2vec2_conformer import Wav2Vec2ConformerConfig + from .models.wav2vec2_bert import ( + Wav2Vec2BertConfig, + Wav2Vec2BertProcessor, + ) + from .models.wav2vec2_conformer import ( + Wav2Vec2ConformerConfig, + ) from .models.wav2vec2_phoneme import Wav2Vec2PhonemeCTCTokenizer from .models.wav2vec2_with_lm import Wav2Vec2ProcessorWithLM from .models.wavlm import WavLMConfig @@ -5579,8 +5779,12 @@ ) from .models.xglm import XGLMConfig from .models.xlm import XLMConfig, XLMTokenizer - from .models.xlm_roberta import XLMRobertaConfig - from .models.xlm_roberta_xl import XLMRobertaXLConfig + from .models.xlm_roberta import ( + XLMRobertaConfig, + ) + from .models.xlm_roberta_xl import ( + XLMRobertaXLConfig, + ) from .models.xlnet import XLNetConfig from .models.xmod import XmodConfig from .models.yolos import YolosConfig @@ -5892,26 +6096,17 @@ ConditionalDetrFeatureExtractor, ConditionalDetrImageProcessor, ) - from .models.convnext import ( - ConvNextFeatureExtractor, - ConvNextImageProcessor, - ) + from .models.convnext import ConvNextFeatureExtractor, ConvNextImageProcessor from .models.deformable_detr import ( DeformableDetrFeatureExtractor, DeformableDetrImageProcessor, ) from .models.deit import DeiTFeatureExtractor, DeiTImageProcessor from .models.deprecated.deta import DetaImageProcessor - from .models.deprecated.efficientformer import ( - EfficientFormerImageProcessor, - ) + from .models.deprecated.efficientformer import EfficientFormerImageProcessor from .models.deprecated.tvlt import TvltImageProcessor from .models.deprecated.vit_hybrid import ViTHybridImageProcessor - from .models.detr import ( - DetrFeatureExtractor, - DetrImageProcessor, - DetrImageProcessorFast, - ) + from .models.detr import DetrFeatureExtractor, DetrImageProcessor, DetrImageProcessorFast from .models.donut import DonutFeatureExtractor, DonutImageProcessor from .models.dpt import DPTFeatureExtractor, DPTImageProcessor from .models.efficientnet import EfficientNetImageProcessor @@ -5926,10 +6121,7 @@ from .models.idefics import IdeficsImageProcessor from .models.idefics2 import Idefics2ImageProcessor from .models.idefics3 import Idefics3ImageProcessor - from .models.imagegpt import ( - ImageGPTFeatureExtractor, - ImageGPTImageProcessor, - ) + from .models.imagegpt import ImageGPTFeatureExtractor, ImageGPTImageProcessor from .models.instructblipvideo import InstructBlipVideoImageProcessor from .models.layoutlmv2 import ( LayoutLMv2FeatureExtractor, @@ -5942,10 +6134,7 @@ from .models.levit import LevitFeatureExtractor, LevitImageProcessor from .models.llava_next import LlavaNextImageProcessor from .models.llava_next_video import LlavaNextVideoImageProcessor - from .models.llava_onevision import ( - LlavaOnevisionImageProcessor, - LlavaOnevisionVideoProcessor, - ) + from .models.llava_onevision import LlavaOnevisionImageProcessor, LlavaOnevisionVideoProcessor from .models.mask2former import Mask2FormerImageProcessor from .models.maskformer import ( MaskFormerFeatureExtractor, @@ -5960,18 +6149,12 @@ MobileNetV2FeatureExtractor, MobileNetV2ImageProcessor, ) - from .models.mobilevit import ( - MobileViTFeatureExtractor, - MobileViTImageProcessor, - ) + from .models.mobilevit import MobileViTFeatureExtractor, MobileViTImageProcessor from .models.nougat import NougatImageProcessor from .models.oneformer import OneFormerImageProcessor from .models.owlv2 import Owlv2ImageProcessor from .models.owlvit import OwlViTFeatureExtractor, OwlViTImageProcessor - from .models.perceiver import ( - PerceiverFeatureExtractor, - PerceiverImageProcessor, - ) + from .models.perceiver import PerceiverFeatureExtractor, PerceiverImageProcessor from .models.pix2struct import Pix2StructImageProcessor from .models.pixtral import PixtralImageProcessor from .models.poolformer import ( @@ -5980,30 +6163,17 @@ ) from .models.pvt import PvtImageProcessor from .models.qwen2_vl import Qwen2VLImageProcessor - from .models.rt_detr import ( - RTDetrImageProcessor, - RTDetrImageProcessorFast, - ) + from .models.rt_detr import RTDetrImageProcessor, RTDetrImageProcessorFast from .models.sam import SamImageProcessor - from .models.segformer import ( - SegformerFeatureExtractor, - SegformerImageProcessor, - ) + from .models.segformer import SegformerFeatureExtractor, SegformerImageProcessor from .models.seggpt import SegGptImageProcessor from .models.siglip import SiglipImageProcessor from .models.superpoint import SuperPointImageProcessor from .models.swin2sr import Swin2SRImageProcessor from .models.tvp import TvpImageProcessor from .models.video_llava import VideoLlavaImageProcessor - from .models.videomae import ( - VideoMAEFeatureExtractor, - VideoMAEImageProcessor, - ) - from .models.vilt import ( - ViltFeatureExtractor, - ViltImageProcessor, - ViltProcessor, - ) + from .models.videomae import VideoMAEFeatureExtractor, VideoMAEImageProcessor + from .models.vilt import ViltFeatureExtractor, ViltImageProcessor, ViltProcessor from .models.vit import ViTFeatureExtractor, ViTImageProcessor from .models.vitmatte import VitMatteImageProcessor from .models.vivit import VivitImageProcessor @@ -6485,7 +6655,10 @@ CvtModel, CvtPreTrainedModel, ) - from .models.dac import DacModel, DacPreTrainedModel + from .models.dac import ( + DacModel, + DacPreTrainedModel, + ) from .models.data2vec import ( Data2VecAudioForAudioFrameClassification, Data2VecAudioForCTC, @@ -6724,7 +6897,10 @@ DistilBertModel, DistilBertPreTrainedModel, ) - from .models.donut import DonutSwinModel, DonutSwinPreTrainedModel + from .models.donut import ( + DonutSwinModel, + DonutSwinPreTrainedModel, + ) from .models.dpr import ( DPRContextEncoder, DPRPretrainedContextEncoder, @@ -6757,7 +6933,10 @@ ElectraPreTrainedModel, load_tf_weights_in_electra, ) - from .models.encodec import EncodecModel, EncodecPreTrainedModel + from .models.encodec import ( + EncodecModel, + EncodecPreTrainedModel, + ) from .models.encoder_decoder import EncoderDecoderModel from .models.ernie import ( ErnieForCausalLM, @@ -6853,7 +7032,10 @@ FunnelPreTrainedModel, load_tf_weights_in_funnel, ) - from .models.fuyu import FuyuForCausalLM, FuyuPreTrainedModel + from .models.fuyu import ( + FuyuForCausalLM, + FuyuPreTrainedModel, + ) from .models.gemma import ( GemmaForCausalLM, GemmaForSequenceClassification, @@ -7155,12 +7337,7 @@ Mamba2Model, Mamba2PreTrainedModel, ) - from .models.marian import ( - MarianForCausalLM, - MarianModel, - MarianMTModel, - MarianPreTrainedModel, - ) + from .models.marian import MarianForCausalLM, MarianModel, MarianMTModel, MarianPreTrainedModel from .models.markuplm import ( MarkupLMForQuestionAnswering, MarkupLMForSequenceClassification, @@ -7204,7 +7381,10 @@ MgpstrModel, MgpstrPreTrainedModel, ) - from .models.mimi import MimiModel, MimiPreTrainedModel + from .models.mimi import ( + MimiModel, + MimiPreTrainedModel, + ) from .models.mistral import ( MistralForCausalLM, MistralForQuestionAnswering, @@ -7476,7 +7656,10 @@ Pix2StructTextModel, Pix2StructVisionModel, ) - from .models.pixtral import PixtralPreTrainedModel, PixtralVisionModel + from .models.pixtral import ( + PixtralPreTrainedModel, + PixtralVisionModel, + ) from .models.plbart import ( PLBartForCausalLM, PLBartForConditionalGeneration, @@ -7634,7 +7817,10 @@ RwkvModel, RwkvPreTrainedModel, ) - from .models.sam import SamModel, SamPreTrainedModel + from .models.sam import ( + SamModel, + SamPreTrainedModel, + ) from .models.seamless_m4t import ( SeamlessM4TCodeHifiGan, SeamlessM4TForSpeechToSpeech, @@ -7799,7 +7985,10 @@ TimesformerPreTrainedModel, ) from .models.timm_backbone import TimmBackbone - from .models.trocr import TrOCRForCausalLM, TrOCRPreTrainedModel + from .models.trocr import ( + TrOCRForCausalLM, + TrOCRPreTrainedModel, + ) from .models.tvp import ( TvpForVideoGrounding, TvpModel, @@ -7901,7 +8090,10 @@ VitMatteForImageMatting, VitMattePreTrainedModel, ) - from .models.vits import VitsModel, VitsPreTrainedModel + from .models.vits import ( + VitsModel, + VitsPreTrainedModel, + ) from .models.vivit import ( VivitForVideoClassification, VivitModel, @@ -8050,11 +8242,7 @@ get_scheduler, get_wsd_schedule, ) - from .pytorch_utils import ( - Conv1D, - apply_chunking_to_forward, - prune_layer, - ) + from .pytorch_utils import Conv1D, apply_chunking_to_forward, prune_layer # Trainer from .trainer import Trainer @@ -8467,11 +8655,7 @@ TFOpenAIGPTModel, TFOpenAIGPTPreTrainedModel, ) - from .models.opt import ( - TFOPTForCausalLM, - TFOPTModel, - TFOPTPreTrainedModel, - ) + from .models.opt import TFOPTForCausalLM, TFOPTModel, TFOPTPreTrainedModel from .models.pegasus import ( TFPegasusForConditionalGeneration, TFPegasusModel, @@ -8535,7 +8719,10 @@ TFRoFormerModel, TFRoFormerPreTrainedModel, ) - from .models.sam import TFSamModel, TFSamPreTrainedModel + from .models.sam import ( + TFSamModel, + TFSamPreTrainedModel, + ) from .models.segformer import ( TFSegformerDecodeHead, TFSegformerForImageClassification, @@ -8573,9 +8760,7 @@ TFTapasPreTrainedModel, ) from .models.vision_encoder_decoder import TFVisionEncoderDecoderModel - from .models.vision_text_dual_encoder import ( - TFVisionTextDualEncoderModel, - ) + from .models.vision_text_dual_encoder import TFVisionTextDualEncoderModel from .models.vit import ( TFViTForImageClassification, TFViTModel, @@ -8665,10 +8850,7 @@ except OptionalDependencyNotAvailable: from .utils.dummy_torchaudio_objects import * else: - from .models.musicgen_melody import ( - MusicgenMelodyFeatureExtractor, - MusicgenMelodyProcessor, - ) + from .models.musicgen_melody import MusicgenMelodyFeatureExtractor, MusicgenMelodyProcessor try: if not is_flax_available(): raise OptionalDependencyNotAvailable() @@ -8875,11 +9057,7 @@ FlaxMT5ForConditionalGeneration, FlaxMT5Model, ) - from .models.opt import ( - FlaxOPTForCausalLM, - FlaxOPTModel, - FlaxOPTPreTrainedModel, - ) + from .models.opt import FlaxOPTForCausalLM, FlaxOPTModel, FlaxOPTPreTrainedModel from .models.pegasus import ( FlaxPegasusForConditionalGeneration, FlaxPegasusModel, @@ -8924,21 +9102,15 @@ FlaxRoFormerModel, FlaxRoFormerPreTrainedModel, ) - from .models.speech_encoder_decoder import ( - FlaxSpeechEncoderDecoderModel, - ) + from .models.speech_encoder_decoder import FlaxSpeechEncoderDecoderModel from .models.t5 import ( FlaxT5EncoderModel, FlaxT5ForConditionalGeneration, FlaxT5Model, FlaxT5PreTrainedModel, ) - from .models.vision_encoder_decoder import ( - FlaxVisionEncoderDecoderModel, - ) - from .models.vision_text_dual_encoder import ( - FlaxVisionTextDualEncoderModel, - ) + from .models.vision_encoder_decoder import FlaxVisionEncoderDecoderModel + from .models.vision_text_dual_encoder import FlaxVisionTextDualEncoderModel from .models.vit import ( FlaxViTForImageClassification, FlaxViTModel, From 2ea53ebabf008feed89022c7db07362fc4d3cb01 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Marcelo=20Tozato?= Date: Mon, 18 Nov 2024 12:22:21 +0100 Subject: [PATCH 28/44] add usage instruction snippet to docs --- docs/source/en/model_doc/ijepa.md | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/docs/source/en/model_doc/ijepa.md b/docs/source/en/model_doc/ijepa.md index c966fe6b02aa..efa79b43400b 100644 --- a/docs/source/en/model_doc/ijepa.md +++ b/docs/source/en/model_doc/ijepa.md @@ -28,6 +28,30 @@ This paper demonstrates an approach for learning highly semantic image represent This model was contributed by [jmtzt](https://huggingface.co/jmtzt). The original code can be found [here](https://github.com/facebookresearch/ijepa). +## How to use + +Here is how to use this model to classify an image of the COCO 2017 dataset into one of the 1,000 ImageNet classes: + +```python +import requests + +from PIL import Image +from transformers import AutoProcessor, IJepaForImageClassification + +url = "http://images.cocodataset.org/val2017/000000039769.jpg" +image = Image.open(requests.get(url, stream=True).raw) + +model_id = "jmtzt/ijepa_vith14_1k" +processor = AutoProcessor.from_pretrained(model_id) +model = IJepaForImageClassification.from_pretrained(model_id) + +inputs = processor(images=image, return_tensors="pt") +outputs = model(**inputs) +logits = outputs.logits +# model predicts one of the 1000 ImageNet classes +predicted_class_idx = logits.argmax(-1).item() +print("Predicted class:", model.config.id2label[predicted_class_idx]) +``` ## IJepaConfig From 13ccd82765659a68bfe4013781ef154678083ecd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Marcelo=20Tozato?= Date: Mon, 18 Nov 2024 12:22:48 +0100 Subject: [PATCH 29/44] change pos encoding, add checkpoint for doc --- .../models/ijepa/modeling_ijepa.py | 61 +++++++++---------- .../models/ijepa/modular_ijepa.py | 58 ++++++++++-------- 2 files changed, 60 insertions(+), 59 deletions(-) diff --git a/src/transformers/models/ijepa/modeling_ijepa.py b/src/transformers/models/ijepa/modeling_ijepa.py index 0d83949309c0..e346b764686d 100644 --- a/src/transformers/models/ijepa/modeling_ijepa.py +++ b/src/transformers/models/ijepa/modeling_ijepa.py @@ -16,14 +16,20 @@ from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput from ...modeling_utils import PreTrainedModel from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer -from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging +from ...utils import ( + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + torch_int, +) from .configuration_ijepa import IJepaConfig logger = logging.get_logger(__name__) -# Base docstring -_CHECKPOINT_FOR_DOC = "google/ijepa-base-patch16-224-in21k" + +_CHECKPOINT_FOR_DOC = "facebook/ijepa_vith14_1k" # General docstring _CONFIG_FOR_DOC = "IJepaConfig" @@ -52,44 +58,36 @@ def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: Adapted from: - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211 - - This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher - resolution images. - - Source: - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174 """ - num_patches = embeddings.shape[1] - 1 - num_positions = self.position_embeddings.shape[1] - 1 - if num_patches == num_positions and height == width: + num_patches = embeddings.shape[1] + num_positions = self.position_embeddings.shape[1] + + # always interpolate when tracing to ensure the exported model works for dynamic input shapes + if not torch.jit.is_tracing() and num_patches == num_positions and height == width: return self.position_embeddings - class_pos_embed = self.position_embeddings[:, 0] - patch_pos_embed = self.position_embeddings[:, 1:] + + patch_pos_embed = self.position_embeddings + dim = embeddings.shape[-1] - h0 = height // self.config.patch_size - w0 = width // self.config.patch_size - # we add a small number to avoid floating point error in the interpolation - # see discussion at https://github.com/facebookresearch/dino/issues/8 - h0, w0 = h0 + 0.1, w0 + 0.1 - patch_pos_embed = patch_pos_embed.reshape( - 1, - int(math.sqrt(num_positions)), - int(math.sqrt(num_positions)), - dim, - ) + + new_height = height // self.patch_size + new_width = width // self.patch_size + + sqrt_num_positions = torch_int(num_positions**0.5) + patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim) patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + patch_pos_embed = nn.functional.interpolate( patch_pos_embed, - scale_factor=( - h0 / math.sqrt(num_positions), - w0 / math.sqrt(num_positions), - ), + size=(new_height, new_width), mode="bicubic", align_corners=False, ) + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) - return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) + + return patch_pos_embed def forward( self, @@ -667,11 +665,10 @@ def __init__(self, config: IJepaConfig) -> None: super().__init__(config) self.num_labels = config.num_labels - self.ijepa = IJepaModel(config, add_pooling_layer=False) + self.ijepa = IJepaModel(config) # Classifier head self.classifier = nn.Linear(config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity() - self.vit = IJepaModel(config) # Initialize weights and apply final processing self.post_init() diff --git a/src/transformers/models/ijepa/modular_ijepa.py b/src/transformers/models/ijepa/modular_ijepa.py index 1870524d38a4..d8bd5f332f6b 100644 --- a/src/transformers/models/ijepa/modular_ijepa.py +++ b/src/transformers/models/ijepa/modular_ijepa.py @@ -1,4 +1,3 @@ -import math from typing import Optional, Union import torch @@ -7,6 +6,9 @@ from transformers.models.ijepa.configuration_ijepa import IJepaConfig from ...modeling_utils import PreTrainedModel +from ...utils import ( + torch_int, +) from ..vit.modeling_vit import ( ViTAttention, ViTEmbeddings, @@ -24,6 +26,9 @@ ) +_CHECKPOINT_FOR_DOC = "facebook/ijepa_vith14_1k" + + class IJepaEmbeddings(ViTEmbeddings): def __init__(self, config: IJepaConfig, use_mask_token: bool = False) -> None: super().__init__(config, use_mask_token) @@ -34,43 +39,42 @@ def __init__(self, config: IJepaConfig, use_mask_token: bool = False) -> None: def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: """ - This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher - resolution images. + This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution + images. This method is also adapted to support torch.jit tracing. - Source: - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174 + Adapted from: + - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and + - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211 """ - num_patches = embeddings.shape[1] - 1 - num_positions = self.position_embeddings.shape[1] - 1 - if num_patches == num_positions and height == width: + num_patches = embeddings.shape[1] + num_positions = self.position_embeddings.shape[1] + + # always interpolate when tracing to ensure the exported model works for dynamic input shapes + if not torch.jit.is_tracing() and num_patches == num_positions and height == width: return self.position_embeddings - class_pos_embed = self.position_embeddings[:, 0] - patch_pos_embed = self.position_embeddings[:, 1:] + + patch_pos_embed = self.position_embeddings + dim = embeddings.shape[-1] - h0 = height // self.config.patch_size - w0 = width // self.config.patch_size - # we add a small number to avoid floating point error in the interpolation - # see discussion at https://github.com/facebookresearch/dino/issues/8 - h0, w0 = h0 + 0.1, w0 + 0.1 - patch_pos_embed = patch_pos_embed.reshape( - 1, - int(math.sqrt(num_positions)), - int(math.sqrt(num_positions)), - dim, - ) + + new_height = height // self.patch_size + new_width = width // self.patch_size + + sqrt_num_positions = torch_int(num_positions**0.5) + patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim) patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + patch_pos_embed = nn.functional.interpolate( patch_pos_embed, - scale_factor=( - h0 / math.sqrt(num_positions), - w0 / math.sqrt(num_positions), - ), + size=(new_height, new_width), mode="bicubic", align_corners=False, ) + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) - return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) + + return patch_pos_embed def forward( self, @@ -194,7 +198,7 @@ class IJepaPooler(ViTPooler): class IJepaForImageClassification(IJepaPreTrainedModel, ViTForImageClassification): def __init__(self, config: IJepaConfig): super().__init__(config) - self.vit = IJepaModel(config) + self.ijepa = IJepaModel(config) self.post_init() From 10cbda2a760525a2560c31197657bd424c7872f3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Marcelo=20Tozato?= Date: Mon, 18 Nov 2024 12:23:17 +0100 Subject: [PATCH 30/44] add verify logits for all models --- .../models/ijepa/convert_ijepa_to_hf.py | 39 ++++++++++--------- 1 file changed, 20 insertions(+), 19 deletions(-) diff --git a/src/transformers/models/ijepa/convert_ijepa_to_hf.py b/src/transformers/models/ijepa/convert_ijepa_to_hf.py index c2ed1611837c..5c15a72ff888 100644 --- a/src/transformers/models/ijepa/convert_ijepa_to_hf.py +++ b/src/transformers/models/ijepa/convert_ijepa_to_hf.py @@ -109,12 +109,6 @@ def read_in_q_k_v(state_dict, config): state_dict[f"encoder.layer.{i}.attention.attention.value.bias"] = in_proj_bias[-config.hidden_size :] -def remove_classification_head_(state_dict): - ignore_keys = ["head.weight", "head.bias"] - for k in ignore_keys: - state_dict.pop(k, None) - - def rename_key(dct, old, new): val = dct.pop(old) dct[new] = val @@ -182,21 +176,33 @@ def write_model(model_name, output_dir, safe_serialization, push_to_hub, verify_ # load HuggingFace model model = IJepaModel(config, add_pooling_layer=False).eval() model.load_state_dict(state_dict) + size = {"height": config.image_size, "width": config.image_size} + image_processor = ViTImageProcessor(size=size) if verify_logits: # Check outputs on an image, prepared by ViTImageProcessor - image_processor = ViTImageProcessor() encoding = image_processor(images=prepare_img(), return_tensors="pt") pixel_values = encoding["pixel_values"] with torch.no_grad(): outputs = model(pixel_values) - expected_slice = torch.Tensor( - [[-0.0621, -0.0054, -2.7513], [-0.1952, 0.0909, -3.9536], [0.0942, -0.0331, -1.2833]] - ) + expected_slices = { + "ijepa_vith14_1k": torch.Tensor( + [[-0.0621, -0.0054, -2.7513], [-0.1952, 0.0909, -3.9536], [0.0942, -0.0331, -1.2833]] + ), + "ijepa_vith14_22k": torch.Tensor( + [[0.0358, -0.0045, -0.2154], [0.0418, -0.0246, 0.0108], [0.2529, -0.0345, -0.0246]] + ), + "ijepa_vith16_1k": torch.Tensor( + [[0.5145, -0.1259, 0.0615], [0.1132, 0.0028, -0.0496], [1.1586, -0.0056, -0.0387]] + ), + "ijepa_vitg16_22k": torch.Tensor( + [[0.0512, -0.0510, -0.0649], [0.1972, 0.0380, -0.0790], [0.1667, -0.0834, -0.1240]] + ), + } assert torch.allclose( - expected_slice, + expected_slices[model_name], outputs.last_hidden_state[0, :3, :3], atol=1e-4, ) @@ -204,17 +210,12 @@ def write_model(model_name, output_dir, safe_serialization, push_to_hub, verify_ if output_dir: Path(output_dir).mkdir(exist_ok=True) print(f"Saving model {model_name} to {output_dir}") + image_processor.save_pretrained(output_dir, safe_serialization=safe_serialization) model.save_pretrained(output_dir, safe_serialization=safe_serialization) if push_to_hub: - model_name_to_hf_name = { - "ijepa_vith14_1k": "ijepa_huge_patch14_1k", - "ijepa_vith14_22k": "ijepa_huge_patch14_22k", - "ijepa_vith16_1k": "ijepa_huge_patch16_1k", - "ijepa_vitg16_22k": "ijepa_giant_patch16_22k", - } - name = model_name_to_hf_name[model_name] - model.push_to_hub(f"jmtzt/{name}", use_temp_dir=True) + image_processor.push_to_hub(repo_id=f"jmtzt/{model_name}", safe_serialization=safe_serialization) + model.push_to_hub(repo_id=f"jmtzt/{model_name}", safe_serialization=safe_serialization) if output_dir: del model, state_dict From 0ccd96e555328c69f2c74443eaa935187b08d095 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Marcelo=20Tozato?= Date: Mon, 18 Nov 2024 12:30:39 +0100 Subject: [PATCH 31/44] [run-slow] ijepa From d2d47d4c1ab5091eb4a2314b8308756c004e2bb1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Marcelo=20Tozato?= Date: Mon, 18 Nov 2024 14:33:21 +0100 Subject: [PATCH 32/44] update docs to include image feature extraction instructions --- docs/source/en/model_doc/ijepa.md | 35 +++++++++++++++++++------------ 1 file changed, 22 insertions(+), 13 deletions(-) diff --git a/docs/source/en/model_doc/ijepa.md b/docs/source/en/model_doc/ijepa.md index efa79b43400b..2191360d241e 100644 --- a/docs/source/en/model_doc/ijepa.md +++ b/docs/source/en/model_doc/ijepa.md @@ -30,27 +30,36 @@ The original code can be found [here](https://github.com/facebookresearch/ijepa) ## How to use -Here is how to use this model to classify an image of the COCO 2017 dataset into one of the 1,000 ImageNet classes: +Here is how to use this model for image feature extraction: ```python import requests - from PIL import Image -from transformers import AutoProcessor, IJepaForImageClassification +from torch.nn.functional import cosine_similarity + +from transformers import AutoModel, AutoProcessor -url = "http://images.cocodataset.org/val2017/000000039769.jpg" -image = Image.open(requests.get(url, stream=True).raw) +url_1 = "http://images.cocodataset.org/val2017/000000039769.jpg" +url_2 = "http://images.cocodataset.org/val2017/000000219578.jpg" +image_1 = Image.open(requests.get(url_1, stream=True).raw) +image_2 = Image.open(requests.get(url_2, stream=True).raw) model_id = "jmtzt/ijepa_vith14_1k" processor = AutoProcessor.from_pretrained(model_id) -model = IJepaForImageClassification.from_pretrained(model_id) - -inputs = processor(images=image, return_tensors="pt") -outputs = model(**inputs) -logits = outputs.logits -# model predicts one of the 1000 ImageNet classes -predicted_class_idx = logits.argmax(-1).item() -print("Predicted class:", model.config.id2label[predicted_class_idx]) +model = AutoModel.from_pretrained(model_id) + + +def infer(image): + inputs = processor(image, return_tensors="pt") + outputs = model(**inputs) + return outputs.pooler_output + + +embed_1 = infer(image_1) +embed_2 = infer(image_2) + +similarity = cosine_similarity(embed_1, embed_2) +print(similarity) ``` ## IJepaConfig From 8e8df55ddf90d5748d8546b5f462d0628b69900d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Marcelo=20Tozato?= Date: Mon, 18 Nov 2024 14:34:05 +0100 Subject: [PATCH 33/44] remove pooling layer from IJepaModel in image classification class --- src/transformers/models/ijepa/modeling_ijepa.py | 2 +- src/transformers/models/ijepa/modular_ijepa.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/ijepa/modeling_ijepa.py b/src/transformers/models/ijepa/modeling_ijepa.py index e346b764686d..2234d7d8d2d4 100644 --- a/src/transformers/models/ijepa/modeling_ijepa.py +++ b/src/transformers/models/ijepa/modeling_ijepa.py @@ -665,7 +665,7 @@ def __init__(self, config: IJepaConfig) -> None: super().__init__(config) self.num_labels = config.num_labels - self.ijepa = IJepaModel(config) + self.ijepa = IJepaModel(config, add_pooling_layer=False) # Classifier head self.classifier = nn.Linear(config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity() diff --git a/src/transformers/models/ijepa/modular_ijepa.py b/src/transformers/models/ijepa/modular_ijepa.py index d8bd5f332f6b..cf44af374d48 100644 --- a/src/transformers/models/ijepa/modular_ijepa.py +++ b/src/transformers/models/ijepa/modular_ijepa.py @@ -198,7 +198,7 @@ class IJepaPooler(ViTPooler): class IJepaForImageClassification(IJepaPreTrainedModel, ViTForImageClassification): def __init__(self, config: IJepaConfig): super().__init__(config) - self.ijepa = IJepaModel(config) + self.ijepa = IJepaModel(config, add_pooling_layer=False) self.post_init() From 50f93d49bc0ff4ba599305034231834deb9b818a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Marcelo=20Tozato?= Date: Mon, 18 Nov 2024 14:34:12 +0100 Subject: [PATCH 34/44] [run-slow] ijepa From db79009643294a7e7af4365c0d703da9c71f9d99 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Marcelo=20Tozato?= Date: Mon, 18 Nov 2024 21:17:53 +0100 Subject: [PATCH 35/44] remove pooling layer from IJepaModel constructor --- src/transformers/models/ijepa/modeling_ijepa.py | 2 +- src/transformers/models/ijepa/modular_ijepa.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/ijepa/modeling_ijepa.py b/src/transformers/models/ijepa/modeling_ijepa.py index 2234d7d8d2d4..13d428d5110f 100644 --- a/src/transformers/models/ijepa/modeling_ijepa.py +++ b/src/transformers/models/ijepa/modeling_ijepa.py @@ -533,7 +533,7 @@ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> No IJEPA_START_DOCSTRING, ) class IJepaModel(IJepaPreTrainedModel): - def __init__(self, config: IJepaConfig, add_pooling_layer: bool = True, use_mask_token: bool = False): + def __init__(self, config: IJepaConfig, add_pooling_layer: bool = False, use_mask_token: bool = False): super().__init__(config) self.config = config self.embeddings = IJepaEmbeddings(config, use_mask_token=use_mask_token) diff --git a/src/transformers/models/ijepa/modular_ijepa.py b/src/transformers/models/ijepa/modular_ijepa.py index cf44af374d48..38a1f601e1bb 100644 --- a/src/transformers/models/ijepa/modular_ijepa.py +++ b/src/transformers/models/ijepa/modular_ijepa.py @@ -180,7 +180,7 @@ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> No class IJepaModel(IJepaPreTrainedModel, ViTModel): - def __init__(self, config: IJepaConfig, add_pooling_layer: bool = True, use_mask_token: bool = False): + def __init__(self, config: IJepaConfig, add_pooling_layer: bool = False, use_mask_token: bool = False): super().__init__(config) self.config = config self.embeddings = IJepaEmbeddings(config, use_mask_token=use_mask_token) From 57e5407b28847d6f62802643d50de422fa3d4055 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Marcelo=20Tozato?= Date: Tue, 19 Nov 2024 09:45:19 +0100 Subject: [PATCH 36/44] update docs --- docs/source/en/model_doc/ijepa.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/en/model_doc/ijepa.md b/docs/source/en/model_doc/ijepa.md index 2191360d241e..9488584ec32a 100644 --- a/docs/source/en/model_doc/ijepa.md +++ b/docs/source/en/model_doc/ijepa.md @@ -52,7 +52,7 @@ model = AutoModel.from_pretrained(model_id) def infer(image): inputs = processor(image, return_tensors="pt") outputs = model(**inputs) - return outputs.pooler_output + return outputs.last_hidden_state.mean(dim=1) embed_1 = infer(image_1) From 8236816759d3924dd5171ef7ce1999e8f6ce30af Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Marcelo=20Tozato?= Date: Tue, 19 Nov 2024 09:45:25 +0100 Subject: [PATCH 37/44] [run-slow] ijepa From ce6499f6f41586b1496fbbac591ae750c661e7d8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Marcelo=20Tozato?= Date: Tue, 19 Nov 2024 10:03:07 +0100 Subject: [PATCH 38/44] [run-slow] ijepa From 81a6e6608c09b89d3be666dedc69c5e8a79084b2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Marcelo=20Tozato?= Date: Tue, 19 Nov 2024 15:16:43 +0100 Subject: [PATCH 39/44] small changes --- docs/source/en/model_doc/ijepa.md | 3 +- .../models/ijepa/modeling_ijepa.py | 33 +++--- .../models/ijepa/modular_ijepa.py | 107 ++++++++++++++++++ tests/models/ijepa/test_modeling_ijepa.py | 35 +++++- 4 files changed, 156 insertions(+), 22 deletions(-) diff --git a/docs/source/en/model_doc/ijepa.md b/docs/source/en/model_doc/ijepa.md index 9488584ec32a..9a0cd368a818 100644 --- a/docs/source/en/model_doc/ijepa.md +++ b/docs/source/en/model_doc/ijepa.md @@ -34,6 +34,7 @@ Here is how to use this model for image feature extraction: ```python import requests +import torch from PIL import Image from torch.nn.functional import cosine_similarity @@ -48,7 +49,7 @@ model_id = "jmtzt/ijepa_vith14_1k" processor = AutoProcessor.from_pretrained(model_id) model = AutoModel.from_pretrained(model_id) - +@torch.no_grad() def infer(image): inputs = processor(image, return_tensors="pt") outputs = model(**inputs) diff --git a/src/transformers/models/ijepa/modeling_ijepa.py b/src/transformers/models/ijepa/modeling_ijepa.py index 13d428d5110f..dc815c6dd8de 100644 --- a/src/transformers/models/ijepa/modeling_ijepa.py +++ b/src/transformers/models/ijepa/modeling_ijepa.py @@ -489,20 +489,6 @@ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> No ).to(module.position_embeddings.dtype) -_EXPECTED_OUTPUT_SHAPE = [1, 197, 768] - - -IJEPA_START_DOCSTRING = r""" - This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it - as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and - behavior. - - Parameters: - config ([`IJepaConfig`]): Model configuration class with all the parameters of the model. - Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. -""" - IJEPA_INPUTS_DOCSTRING = r""" Args: pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): @@ -526,6 +512,19 @@ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> No return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. """ +_EXPECTED_OUTPUT_SHAPE = [1, 197, 768] + + +IJEPA_START_DOCSTRING = r""" + This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it + as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and + behavior. + + Parameters: + config ([`IJepaConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" @add_start_docstrings( @@ -647,8 +646,8 @@ def forward(self, hidden_states): @add_start_docstrings( """ - IJepa Model transformer with an image classification head on top (a linear layer on top of the final hidden state of - the [CLS] token) e.g. for ImageNet. + IJepa Model transformer with an image classification head on top (a linear layer on top of the final hidden states) + e.g. for ImageNet. @@ -709,7 +708,7 @@ def forward( sequence_output = outputs[0] - logits = self.classifier(sequence_output[:, 0, :]) + logits = self.classifier(sequence_output.mean(dim=1)) loss = None if labels is not None: diff --git a/src/transformers/models/ijepa/modular_ijepa.py b/src/transformers/models/ijepa/modular_ijepa.py index 38a1f601e1bb..4be286b27fe2 100644 --- a/src/transformers/models/ijepa/modular_ijepa.py +++ b/src/transformers/models/ijepa/modular_ijepa.py @@ -2,11 +2,14 @@ import torch import torch.nn as nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from transformers.models.ijepa.configuration_ijepa import IJepaConfig +from ...modeling_outputs import ImageClassifierOutput from ...modeling_utils import PreTrainedModel from ...utils import ( + add_start_docstrings, torch_int, ) from ..vit.modeling_vit import ( @@ -179,6 +182,24 @@ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> No ).to(module.position_embeddings.dtype) +_EXPECTED_OUTPUT_SHAPE = [1, 256, 1280] + +IJEPA_START_DOCSTRING = r""" + This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it + as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and + behavior. + + Parameters: + config ([`IJepaConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare IJepa Model transformer outputting raw hidden-states without any specific head on top.", + IJEPA_START_DOCSTRING, +) class IJepaModel(IJepaPreTrainedModel, ViTModel): def __init__(self, config: IJepaConfig, add_pooling_layer: bool = False, use_mask_token: bool = False): super().__init__(config) @@ -195,12 +216,98 @@ class IJepaPooler(ViTPooler): pass +_IMAGE_CLASS_CHECKPOINT = "jmtzt/ijepa_vith14_1k" +_IMAGE_CLASS_EXPECTED_OUTPUT = "Egyptian cat" + + +@add_start_docstrings( + """ + IJepa Model transformer with an image classification head on top (a linear layer on top of the final hidden states) + e.g. for ImageNet. + + + + Note that it's possible to fine-tune IJepa on higher resolution images than the ones it has been trained on, by + setting `interpolate_pos_encoding` to `True` in the forward of the model. This will interpolate the pre-trained + position embeddings to the higher resolution. + + + """, + IJEPA_START_DOCSTRING, +) class IJepaForImageClassification(IJepaPreTrainedModel, ViTForImageClassification): def __init__(self, config: IJepaConfig): super().__init__(config) self.ijepa = IJepaModel(config, add_pooling_layer=False) self.post_init() + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, ImageClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the image classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.ijepa( + pixel_values, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + logits = self.classifier(sequence_output.mean(dim=1)) + + loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return ImageClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + __all__ = [ "IJepaPreTrainedModel", diff --git a/tests/models/ijepa/test_modeling_ijepa.py b/tests/models/ijepa/test_modeling_ijepa.py index bde91cff7b6e..27a79bc67242 100644 --- a/tests/models/ijepa/test_modeling_ijepa.py +++ b/tests/models/ijepa/test_modeling_ijepa.py @@ -250,7 +250,7 @@ def test_for_image_classification(self): @slow def test_model_from_pretrained(self): - model_name = "jmtzt/ijepa_huge_patch14_1k" + model_name = "jmtzt/ijepa_vith14_1k" model = IJepaModel.from_pretrained(model_name) self.assertIsNotNone(model) @@ -266,11 +266,11 @@ def prepare_img(): class IJepaModelIntegrationTest(unittest.TestCase): @cached_property def default_image_processor(self): - return ViTImageProcessor.from_pretrained("google/vit-base-patch16-224") if is_vision_available() else None + return ViTImageProcessor.from_pretrained("jmtzt/ijepa_vith14_1k") if is_vision_available() else None @slow def test_inference_no_head(self): - model = IJepaModel.from_pretrained("jmtzt/ijepa_huge_patch14_1k").to(torch_device) + model = IJepaModel.from_pretrained("jmtzt/ijepa_vith14_1k").to(torch_device) image_processor = self.default_image_processor image = prepare_img() @@ -299,7 +299,7 @@ def test_inference_fp16(self): A small test to make sure that inference work in half precision without any problem. """ model = IJepaModel.from_pretrained( - "jmtzt/ijepa_huge_patch14_1k", + "jmtzt/ijepa_vith14_1k", torch_dtype=torch.float16, device_map="auto", ) @@ -312,3 +312,30 @@ def test_inference_fp16(self): # forward pass to make sure inference works in fp16 with torch.no_grad(): _ = model(pixel_values) + + @slow + def test_inference_interpolate_pos_encoding(self): + # I-JEPA, similar to ViT models have an `interpolate_pos_encoding` argument in their forward method, + # allowing to interpolate the pre-trained position embeddings in order to use + # the model on higher resolutions. The DINO model by Facebook AI leverages this + # to visualize self-attention on higher resolution images. + model = IJepaModel.from_pretrained("jmtzt/ijepa_vith14_1k").to(torch_device) + + image_processor = self.default_image_processor + image = prepare_img() + inputs = image_processor(images=image, return_tensors="pt") + pixel_values = inputs.pixel_values.to(torch_device) + + # forward pass + with torch.no_grad(): + outputs = model(pixel_values, interpolate_pos_encoding=True) + + # verify the logits + expected_shape = torch.Size((1, 256, 1280)) + self.assertEqual(outputs.last_hidden_state.shape, expected_shape) + + expected_slice = torch.tensor( + [[-0.0621, -0.0054, -2.7513], [-0.1952, 0.0909, -3.9536], [0.0942, -0.0331, -1.2833]] + ).to(torch_device) + + self.assertTrue(torch.allclose(outputs.last_hidden_state[0, :3, :3], expected_slice, atol=1e-4)) From 7a0fc390696f406cddfde68f375c8f620669654b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Marcelo=20Tozato?= Date: Tue, 19 Nov 2024 15:25:42 +0100 Subject: [PATCH 40/44] [run-slow] ijepa From 37a38f9c3592c679b43075068215f389477ba132 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Marcelo=20Tozato?= Date: Tue, 26 Nov 2024 16:33:14 +0100 Subject: [PATCH 41/44] style adjustments --- .../models/auto/configuration_auto.py | 21 +++------------ .../models/auto/image_processing_auto.py | 19 +++---------- src/transformers/models/auto/modeling_auto.py | 27 +++++-------------- 3 files changed, 15 insertions(+), 52 deletions(-) diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 9ec932fe7eec..cb45149671b1 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -22,10 +22,7 @@ from typing import List, Union from ...configuration_utils import PretrainedConfig -from ...dynamic_module_utils import ( - get_class_from_dynamic_module, - resolve_trust_remote_code, -) +from ...dynamic_module_utils import get_class_from_dynamic_module, resolve_trust_remote_code from ...utils import CONFIG_NAME, logging @@ -882,11 +879,7 @@ def docstring_decorator(fn): indent = re.search(r"^(\s*)List options\s*$", lines[i]).groups()[0] if use_model_types: indent = f"{indent} " - lines[i] = _list_model_options( - indent, - config_to_class=config_to_class, - use_model_types=use_model_types, - ) + lines[i] = _list_model_options(indent, config_to_class=config_to_class, use_model_types=use_model_types) docstrings = "\n".join(lines) else: raise ValueError( @@ -1027,19 +1020,13 @@ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): has_remote_code = "auto_map" in config_dict and "AutoConfig" in config_dict["auto_map"] has_local_code = "model_type" in config_dict and config_dict["model_type"] in CONFIG_MAPPING trust_remote_code = resolve_trust_remote_code( - trust_remote_code, - pretrained_model_name_or_path, - has_local_code, - has_remote_code, + trust_remote_code, pretrained_model_name_or_path, has_local_code, has_remote_code ) if has_remote_code and trust_remote_code: class_ref = config_dict["auto_map"]["AutoConfig"] config_class = get_class_from_dynamic_module( - class_ref, - pretrained_model_name_or_path, - code_revision=code_revision, - **kwargs, + class_ref, pretrained_model_name_or_path, code_revision=code_revision, **kwargs ) if os.path.isdir(pretrained_model_name_or_path): config_class.register_for_auto_class() diff --git a/src/transformers/models/auto/image_processing_auto.py b/src/transformers/models/auto/image_processing_auto.py index b7202ff578d9..dfb9df59bd21 100644 --- a/src/transformers/models/auto/image_processing_auto.py +++ b/src/transformers/models/auto/image_processing_auto.py @@ -23,10 +23,7 @@ # Build the list of all image processors from ...configuration_utils import PretrainedConfig -from ...dynamic_module_utils import ( - get_class_from_dynamic_module, - resolve_trust_remote_code, -) +from ...dynamic_module_utils import get_class_from_dynamic_module, resolve_trust_remote_code from ...image_processing_utils import BaseImageProcessor, ImageProcessingMixin from ...image_processing_utils_fast import BaseImageProcessorFast from ...utils import ( @@ -169,10 +166,7 @@ else: fast_image_processor_class = fast_image_processor_class[0] - IMAGE_PROCESSOR_MAPPING_NAMES[model_type] = ( - slow_image_processor_class, - fast_image_processor_class, - ) + IMAGE_PROCESSOR_MAPPING_NAMES[model_type] = (slow_image_processor_class, fast_image_processor_class) IMAGE_PROCESSOR_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, IMAGE_PROCESSOR_MAPPING_NAMES) @@ -461,10 +455,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs): has_remote_code = image_processor_auto_map is not None has_local_code = image_processor_class is not None or type(config) in IMAGE_PROCESSOR_MAPPING trust_remote_code = resolve_trust_remote_code( - trust_remote_code, - pretrained_model_name_or_path, - has_local_code, - has_remote_code, + trust_remote_code, pretrained_model_name_or_path, has_local_code, has_remote_code ) if image_processor_auto_map is not None and not isinstance(image_processor_auto_map, tuple): @@ -565,7 +556,5 @@ def register( fast_image_processor_class = existing_fast IMAGE_PROCESSOR_MAPPING.register( - config_class, - (slow_image_processor_class, fast_image_processor_class), - exist_ok=exist_ok, + config_class, (slow_image_processor_class, fast_image_processor_class), exist_ok=exist_ok ) diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index e0cd0c9423a1..32667a0893ea 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -641,10 +641,7 @@ ("data2vec-vision", "Data2VecVisionForImageClassification"), ( "deit", - ( - "DeiTForImageClassification", - "DeiTForImageClassificationWithTeacher", - ), + ("DeiTForImageClassification", "DeiTForImageClassificationWithTeacher"), ), ("dinat", "DinatForImageClassification"), ("dinov2", "Dinov2ForImageClassification"), @@ -662,10 +659,7 @@ ("imagegpt", "ImageGPTForImageClassification"), ( "levit", - ( - "LevitForImageClassification", - "LevitForImageClassificationWithTeacher", - ), + ("LevitForImageClassification", "LevitForImageClassificationWithTeacher"), ), ("mobilenet_v1", "MobileNetV1ForImageClassification"), ("mobilenet_v2", "MobileNetV2ForImageClassification"), @@ -1000,10 +994,7 @@ ("reformer", "ReformerForSequenceClassification"), ("rembert", "RemBertForSequenceClassification"), ("roberta", "RobertaForSequenceClassification"), - ( - "roberta-prelayernorm", - "RobertaPreLayerNormForSequenceClassification", - ), + ("roberta-prelayernorm", "RobertaPreLayerNormForSequenceClassification"), ("roc_bert", "RoCBertForSequenceClassification"), ("roformer", "RoFormerForSequenceClassification"), ("squeezebert", "SqueezeBertForSequenceClassification"), @@ -1449,8 +1440,7 @@ CONFIG_MAPPING_NAMES, MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES ) MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING = _LazyAutoMapping( - CONFIG_MAPPING_NAMES, - MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES, + CONFIG_MAPPING_NAMES, MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES ) MODEL_FOR_IMAGE_SEGMENTATION_MAPPING = _LazyAutoMapping( CONFIG_MAPPING_NAMES, MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES @@ -1690,8 +1680,7 @@ class AutoModelForZeroShotImageClassification(_BaseAutoModelClass): AutoModelForZeroShotImageClassification = auto_class_update( - AutoModelForZeroShotImageClassification, - head_doc="zero-shot image classification", + AutoModelForZeroShotImageClassification, head_doc="zero-shot image classification" ) @@ -1792,8 +1781,7 @@ class AutoModelForSpeechSeq2Seq(_BaseAutoModelClass): AutoModelForSpeechSeq2Seq = auto_class_update( - AutoModelForSpeechSeq2Seq, - head_doc="sequence-to-sequence speech-to-text modeling", + AutoModelForSpeechSeq2Seq, head_doc="sequence-to-sequence speech-to-text modeling" ) @@ -1802,8 +1790,7 @@ class AutoModelForAudioFrameClassification(_BaseAutoModelClass): AutoModelForAudioFrameClassification = auto_class_update( - AutoModelForAudioFrameClassification, - head_doc="audio frame (token) classification", + AutoModelForAudioFrameClassification, head_doc="audio frame (token) classification" ) From 491d5a5c976628a2e19e48085453b381c4b4fd0c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Marcelo=20Tozato?= Date: Tue, 26 Nov 2024 16:33:41 +0100 Subject: [PATCH 42/44] update copyright in init file --- src/transformers/models/ijepa/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/ijepa/__init__.py b/src/transformers/models/ijepa/__init__.py index adaefe9ae43c..efc8c90b1762 100644 --- a/src/transformers/models/ijepa/__init__.py +++ b/src/transformers/models/ijepa/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2024 The HuggingFace Team. All rights reserved. +# Copyright 2023 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. From 2afaba03d242b24003f1e27c5c312e649a56426d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Marcelo=20Tozato?= Date: Tue, 26 Nov 2024 16:34:00 +0100 Subject: [PATCH 43/44] adjust modular ijepa --- .../models/ijepa/modeling_ijepa.py | 159 +++++++++--------- .../models/ijepa/modular_ijepa.py | 61 ------- 2 files changed, 80 insertions(+), 140 deletions(-) diff --git a/src/transformers/models/ijepa/modeling_ijepa.py b/src/transformers/models/ijepa/modeling_ijepa.py index dc815c6dd8de..df254455bad5 100644 --- a/src/transformers/models/ijepa/modeling_ijepa.py +++ b/src/transformers/models/ijepa/modeling_ijepa.py @@ -35,6 +35,45 @@ _CONFIG_FOR_DOC = "IJepaConfig" +class IJepaPatchEmbeddings(nn.Module): + """ + This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial + `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a + Transformer. + """ + + def __init__(self, config): + super().__init__() + image_size, patch_size = config.image_size, config.patch_size + num_channels, hidden_size = config.num_channels, config.hidden_size + + image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size) + patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size) + num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.num_patches = num_patches + + self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size) + + def forward(self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False) -> torch.Tensor: + batch_size, num_channels, height, width = pixel_values.shape + if num_channels != self.num_channels: + raise ValueError( + "Make sure that the channel dimension of the pixel values match with the one set in the configuration." + f" Expected {self.num_channels} but got {num_channels}." + ) + if not interpolate_pos_encoding: + if height != self.image_size[0] or width != self.image_size[1]: + raise ValueError( + f"Input image size ({height}*{width}) doesn't match model" + f" ({self.image_size[0]}*{self.image_size[1]})." + ) + embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2) + return embeddings + + class IJepaEmbeddings(nn.Module): """ Construct the CLS token, position and patch embeddings. Optionally, also the mask token. @@ -116,43 +155,38 @@ def forward( return embeddings -class IJepaPatchEmbeddings(nn.Module): +class IJepaPreTrainedModel(PreTrainedModel): """ - This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial - `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a - Transformer. + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. """ - def __init__(self, config): - super().__init__() - image_size, patch_size = config.image_size, config.patch_size - num_channels, hidden_size = config.num_channels, config.hidden_size - - image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size) - patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size) - num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) - self.image_size = image_size - self.patch_size = patch_size - self.num_channels = num_channels - self.num_patches = num_patches - - self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size) + config_class = IJepaConfig + base_model_prefix = "ijepa" + main_input_name = "pixel_values" + supports_gradient_checkpointing = True + _no_split_modules = ["IJepaEmbeddings", "IJepaLayer"] + _supports_sdpa = True - def forward(self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False) -> torch.Tensor: - batch_size, num_channels, height, width = pixel_values.shape - if num_channels != self.num_channels: - raise ValueError( - "Make sure that the channel dimension of the pixel values match with the one set in the configuration." - f" Expected {self.num_channels} but got {num_channels}." - ) - if not interpolate_pos_encoding: - if height != self.image_size[0] or width != self.image_size[1]: - raise ValueError( - f"Input image size ({height}*{width}) doesn't match model" - f" ({self.image_size[0]}*{self.image_size[1]})." - ) - embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2) - return embeddings + def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None: + """Initialize the weights""" + if isinstance(module, (nn.Linear, nn.Conv2d)): + # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid + # `trunc_normal_cpu` not implemented in `half` issues + module.weight.data = nn.init.trunc_normal_( + module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range + ).to(module.weight.dtype) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, IJepaEmbeddings): + module.position_embeddings.data = nn.init.trunc_normal_( + module.position_embeddings.data.to(torch.float32), + mean=0.0, + std=self.config.initializer_range, + ).to(module.position_embeddings.dtype) class IJepaSelfAttention(nn.Module): @@ -455,38 +489,19 @@ def forward( ) -class IJepaPreTrainedModel(PreTrainedModel): - """ - An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained - models. - """ - - config_class = IJepaConfig - base_model_prefix = "ijepa" - main_input_name = "pixel_values" - supports_gradient_checkpointing = True - _no_split_modules = ["IJepaEmbeddings", "IJepaLayer"] - _supports_sdpa = True +class IJepaPooler(nn.Module): + def __init__(self, config: IJepaConfig): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() - def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None: - """Initialize the weights""" - if isinstance(module, (nn.Linear, nn.Conv2d)): - # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid - # `trunc_normal_cpu` not implemented in `half` issues - module.weight.data = nn.init.trunc_normal_( - module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range - ).to(module.weight.dtype) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) - elif isinstance(module, IJepaEmbeddings): - module.position_embeddings.data = nn.init.trunc_normal_( - module.position_embeddings.data.to(torch.float32), - mean=0.0, - std=self.config.initializer_range, - ).to(module.position_embeddings.dtype) + def forward(self, hidden_states): + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output IJEPA_INPUTS_DOCSTRING = r""" @@ -537,6 +552,7 @@ def __init__(self, config: IJepaConfig, add_pooling_layer: bool = False, use_mas self.config = config self.embeddings = IJepaEmbeddings(config, use_mask_token=use_mask_token) self.encoder = IJepaEncoder(config) + self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.pooler = IJepaPooler(config) if add_pooling_layer else None @@ -624,21 +640,6 @@ def forward( ) -class IJepaPooler(nn.Module): - def __init__(self, config: IJepaConfig): - super().__init__() - self.dense = nn.Linear(config.hidden_size, config.hidden_size) - self.activation = nn.Tanh() - - def forward(self, hidden_states): - # We "pool" the model by simply taking the hidden state corresponding - # to the first token. - first_token_tensor = hidden_states[:, 0] - pooled_output = self.dense(first_token_tensor) - pooled_output = self.activation(pooled_output) - return pooled_output - - # Image classification docstring _IMAGE_CLASS_CHECKPOINT = "google/ijepa-base-patch16-224" _IMAGE_CLASS_EXPECTED_OUTPUT = "Egyptian cat" diff --git a/src/transformers/models/ijepa/modular_ijepa.py b/src/transformers/models/ijepa/modular_ijepa.py index 4be286b27fe2..efbd71d91342 100644 --- a/src/transformers/models/ijepa/modular_ijepa.py +++ b/src/transformers/models/ijepa/modular_ijepa.py @@ -13,19 +13,9 @@ torch_int, ) from ..vit.modeling_vit import ( - ViTAttention, ViTEmbeddings, - ViTEncoder, ViTForImageClassification, - ViTIntermediate, - ViTLayer, ViTModel, - ViTPatchEmbeddings, - ViTPooler, - ViTSdpaAttention, - ViTSdpaSelfAttention, - ViTSelfAttention, - ViTSelfOutput, ) @@ -106,48 +96,6 @@ def forward( return embeddings -class IJepaPatchEmbeddings(ViTPatchEmbeddings): - pass - - -class IJepaSelfAttention(ViTSelfAttention): - pass - - -class IJepaSdpaSelfAttention(ViTSdpaSelfAttention): - pass - - -class IJepaSelfOutput(ViTSelfOutput): - pass - - -class IJepaAttention(ViTAttention): - pass - - -class IJepaSdpaAttention(ViTSdpaAttention): - pass - - -class IJepaIntermediate(ViTIntermediate): - pass - - -IJepa_ATTENTION_CLASSES = { - "eager": IJepaAttention, - "sdpa": IJepaSdpaAttention, -} - - -class IJepaLayer(ViTLayer): - pass - - -class IJepaEncoder(ViTEncoder): - pass - - class IJepaPreTrainedModel(PreTrainedModel): """ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained @@ -205,15 +153,6 @@ def __init__(self, config: IJepaConfig, add_pooling_layer: bool = False, use_mas super().__init__(config) self.config = config self.embeddings = IJepaEmbeddings(config, use_mask_token=use_mask_token) - self.encoder = IJepaEncoder(config) - self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - self.pooler = IJepaPooler(config) if add_pooling_layer else None - # Initialize weights and apply final processing - self.post_init() - - -class IJepaPooler(ViTPooler): - pass _IMAGE_CLASS_CHECKPOINT = "jmtzt/ijepa_vith14_1k" From db4dfc0ee9fd5806cebf2c1c8fcf513994ee88bb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Marcelo=20Tozato?= Date: Tue, 26 Nov 2024 16:34:09 +0100 Subject: [PATCH 44/44] [run-slow] ijepa