From 92a22239a854739705cde8d3ba13ad8e5c90463a Mon Sep 17 00:00:00 2001 From: AlanPonnachan Date: Tue, 7 Jan 2025 20:06:19 +0530 Subject: [PATCH 1/4] Add AIMv2- Updated all files --- docs/source/en/model_doc/aimv2.md | 62 +++ src/transformers/__init__.py | 5 + src/transformers/models/__init__.py | 1 + src/transformers/models/aimv2/__init__.py | 54 +++ .../models/aimv2/configuration_aimv2.py | 99 +++++ .../models/aimv2/convert_aimv2_to_hf.py | 225 +++++++++++ .../models/aimv2/modeling_aimv2.py | 354 ++++++++++++++++++ .../models/auto/configuration_auto.py | 2 + src/transformers/models/auto/modeling_auto.py | 3 + tests/models/aimv2/test_modeling_aimv2.py | 293 +++++++++++++++ 10 files changed, 1098 insertions(+) create mode 100644 docs/source/en/model_doc/aimv2.md create mode 100644 src/transformers/models/aimv2/__init__.py create mode 100644 src/transformers/models/aimv2/configuration_aimv2.py create mode 100644 src/transformers/models/aimv2/convert_aimv2_to_hf.py create mode 100644 src/transformers/models/aimv2/modeling_aimv2.py create mode 100644 tests/models/aimv2/test_modeling_aimv2.py diff --git a/docs/source/en/model_doc/aimv2.md b/docs/source/en/model_doc/aimv2.md new file mode 100644 index 000000000000..fc6c084db4da --- /dev/null +++ b/docs/source/en/model_doc/aimv2.md @@ -0,0 +1,62 @@ + + +# AIMV2 + +## Overview + +The AIMV2 model was proposed in [Multimodal Autoregressive Pre-training of Large Vision Encoders](https://arxiv.org/abs/2411.14402) by Enrico Fini, Mustafa Shukor, Xiujun Li, Philipp Dufter, Michal Klein, David Haldimann, Sai Aitharaju, Victor Guilherme Turrisi da Costa, Louis Béthune, Zhe Gan, Alexander T Toshev, Marcin Eichner, Moin Nabi, Yinfei Yang, Joshua M. Susskind, and Alaaeldin El-Nouby. +AIMV2, a family of generalist vision encoders characterized by a straightforward pre-training process, scalability, and remarkable performance across a range of downstream tasks. + +The abstract from the paper is the following: + +*We introduce a novel method for pre-training of large-scale +vision encoders. Building on recent advancements in autoregressive pre-training of vision models, we extend this +framework to a multimodal setting, i.e., images and text. In +this paper, we present AIMV2, a family of generalist vision +encoders characterized by a straightforward pre-training +process, scalability, and remarkable performance across a +range of downstream tasks. This is achieved by pairing the +vision encoder with a multimodal decoder that autoregressively generates raw image patches and text tokens. Our +encoders excel not only in multimodal evaluations but also +in vision benchmarks such as localization, grounding, and +classification. Notably, our AIMV2-3B encoder achieves +89.5% accuracy on ImageNet-1k with a frozen trunk. Furthermore, AIMV2 consistently outperforms state-of-the-art +contrastive models (e.g., CLIP, SigLIP) in multimodal image understanding across diverse settings. +* + +Tips: + +- The model is best suited for fine-tuning on downstream vision tasks such as image classification, object detection, and semantic segmentation. +- When using the model for inference, make sure to use an `AutoImageProcessor` (or manually process the images) to ensure the input images are preprocessed correctly (resized, normalized, etc.). The recommended image size for AIMv2 is typically 224x224, though some variants are trained on other resolutions (e.g., 336x336, 448x448). See the specific model checkpoint's documentation for details. +- AIMv2 models are trained using masked image modeling. If using the model for transfer learning, you may notice better performance by incorporating masked data during fine-tuning. + +This model was contributed by [AlanPonnachan](https://huggingface.co/AlanPonnachan). +The original code can be found [here](https://github.com/apple/ml-aim). + + +## AIMv2Config + +[[autodoc]] AIMv2Config + +## AIMv2Model + +[[autodoc]] AIMv2Model + - forward + + + + diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 5510ac6c8ad5..7966c1feab9d 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -5093,6 +5093,7 @@ load_tf2_model_in_pytorch_model, load_tf2_weights_in_pytorch_model, ) + from .models.aimv2 import AIMv2Config from .models.albert import AlbertConfig from .models.align import ( AlignConfig, @@ -6398,6 +6399,10 @@ ) from .modeling_rope_utils import ROPE_INIT_FUNCTIONS from .modeling_utils import PreTrainedModel + from .models.aimv2 import ( + AIMv2Model, + AIMv2PreTrainedModel, + ) from .models.albert import ( AlbertForMaskedLM, AlbertForMultipleChoice, diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index 7fcaddde704c..9b3e80a6e3aa 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -13,6 +13,7 @@ # limitations under the License. from . import ( + aimv2, albert, align, altclip, diff --git a/src/transformers/models/aimv2/__init__.py b/src/transformers/models/aimv2/__init__.py new file mode 100644 index 000000000000..870676629c13 --- /dev/null +++ b/src/transformers/models/aimv2/__init__.py @@ -0,0 +1,54 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING + +from transformers.utils import ( + OptionalDependencyNotAvailable, + _LazyModule, + is_torch_available, +) + + +_import_structure = { + "configuration_aimv2": ["AIMv2Config"], + "modeling_aimv2": ["AIMv2Model", "AIMv2PreTrainedModel"], +} + +try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["modeling_aimv2"] = [ + "AIMv2Model", + "AIMv2PreTrainedModel", + ] + +if TYPE_CHECKING: + from .configuration_aimv2 import AIMv2Config + + try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .modeling_aimv2 import AIMv2Model, AIMv2PreTrainedModel + +else: + import sys + + sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) diff --git a/src/transformers/models/aimv2/configuration_aimv2.py b/src/transformers/models/aimv2/configuration_aimv2.py new file mode 100644 index 000000000000..cf7e353831c2 --- /dev/null +++ b/src/transformers/models/aimv2/configuration_aimv2.py @@ -0,0 +1,99 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""AIMv2 model configuration""" + +from typing import Any + +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging + + +logger = logging.get_logger(__name__) + +__all__ = ["AIMv2Config"] + + +class AIMv2Config(PretrainedConfig): + """ + This is the configuration class to store the configuration of an [`AIMv2Model`]. + Instantiating a configuration with the defaults will yield a similar configuration + to that of the [apple/aimv2-large-patch14-224](https://huggingface.co/apple/aimv2-large-patch14-224) architecture. + + Args: + hidden_size (`int`, *optional*, defaults to 1024): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 2816): + Dimension of the SwiGLU representations. + num_hidden_layers (`int`, *optional*, defaults to 24): + Number of hidden layers in the Transformer. + num_attention_heads (`int`, *optional*, defaults to 8): + Number of attention heads for each attention layer + in the Transformer. + num_channels (`int`, *optional*, defaults to 3): + Number of input channels. + image_size (`int`, *optional*, defaults to 224): + Image size. + patch_size (`int`, *optional*, defaults to 14): + Patch size. + rms_norm_eps (`float`, *optional*, defaults to 1e-5): + Epsilon value used for the RMS normalization layer. + attention_dropout (`float`, *optional*, defaults to 0.0): + Dropout ratio for attention probabilities. + projection_dropout (`float`, *optional*, defaults to 0.0): + Dropout ratio for the projection layer after the attention. + qkv_bias (`bool`, *optional*, defaults to `False`): + Whether to add a bias to the queries, keys and values. + use_bias (`bool`, *optional*, defaults to `False`): + Whether to add a bias in the feed-forward and projection layers. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for + initializing all weight matrices. + kwargs: + Keyword arguments for the [`PretrainedConfig`]. + """ + + model_type: str = "aimv2" + + def __init__( + self, + hidden_size: int = 1024, + intermediate_size: int = 2816, + num_hidden_layers: int = 24, + num_attention_heads: int = 8, + num_channels: int = 3, + image_size: int = 224, + patch_size: int = 14, + rms_norm_eps: float = 1e-5, + attention_dropout: float = 0.0, + projection_dropout: float = 0.0, + qkv_bias: bool = False, + use_bias: bool = False, + initializer_range: float = 0.02, + **kwargs: Any, + ): + super().__init__(**kwargs) + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_channels = num_channels + self.patch_size = patch_size + self.image_size = image_size + self.attention_dropout = attention_dropout + self.rms_norm_eps = rms_norm_eps + self.initializer_range = initializer_range + self.projection_dropout = projection_dropout + self.qkv_bias = qkv_bias + self.use_bias = use_bias diff --git a/src/transformers/models/aimv2/convert_aimv2_to_hf.py b/src/transformers/models/aimv2/convert_aimv2_to_hf.py new file mode 100644 index 000000000000..9a7b652de149 --- /dev/null +++ b/src/transformers/models/aimv2/convert_aimv2_to_hf.py @@ -0,0 +1,225 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert AIMv2 checkpoints from the original repository. + +URL: https://github.com/apple/ml-aim/tree/main/aim-v2 +""" + +import argparse +from pathlib import Path + +import requests +import torch +from huggingface_hub import hf_hub_download +from PIL import Image +from safetensors import safe_open + +from transformers import AIMv2Config, AIMv2Model, AutoImageProcessor +from transformers.utils import logging + + +logging.set_verbosity_info() +logger = logging.get_logger(__name__) + + +def get_aimv2_config(model_name): + config = AIMv2Config() + + # Use the appropriate hyperparameters depending on the model name. + if "aimv2-base" in model_name: + config.hidden_size = 768 + config.intermediate_size = 2048 + config.num_hidden_layers = 12 + config.num_attention_heads = 6 + elif "aimv2-large" in model_name: + config.hidden_size = 1024 + config.intermediate_size = 2816 + config.num_hidden_layers = 24 + config.num_attention_heads = 8 + elif "aimv2-huge" in model_name: + config.hidden_size = 1536 + config.intermediate_size = 4096 + config.num_hidden_layers = 24 + config.num_attention_heads = 12 + elif "aimv2-1B" in model_name: + config.hidden_size = 2048 + config.intermediate_size = 5632 + config.num_hidden_layers = 24 + config.num_attention_heads = 16 + elif "aimv2-3B" in model_name: + config.hidden_size = 3072 + config.intermediate_size = 8192 + config.num_hidden_layers = 24 + config.num_attention_heads = 24 + + return config + + +def create_rename_keys(config): + rename_keys = [] + # fmt: off + + # patch embedding layer + rename_keys.append(("preprocessor.pos_embed", "preprocessor.pos_embed")) + rename_keys.append(("preprocessor.patchifier.proj.weight", "preprocessor.patchifier.proj.weight")) + rename_keys.append(("preprocessor.patchifier.proj.bias", "preprocessor.patchifier.proj.bias")) + rename_keys.append(("preprocessor.patchifier.norm.weight", "preprocessor.patchifier.norm.weight")) + + for i in range(config.num_hidden_layers): + # attention blocks + rename_keys.append((f"trunk.blocks.{i}.attn.qkv.weight", f"trunk.blocks.{i}.attn.qkv.weight")) + rename_keys.append((f"trunk.blocks.{i}.attn.proj.weight", f"trunk.blocks.{i}.attn.proj.weight")) + + # MLP blocks + rename_keys.append((f"trunk.blocks.{i}.norm_1.weight", f"trunk.blocks.{i}.norm_1.weight")) + rename_keys.append((f"trunk.blocks.{i}.mlp.fc1.weight", f"trunk.blocks.{i}.mlp.fc1.weight")) + rename_keys.append((f"trunk.blocks.{i}.mlp.fc2.weight", f"trunk.blocks.{i}.mlp.fc2.weight")) + rename_keys.append((f"trunk.blocks.{i}.mlp.fc3.weight", f"trunk.blocks.{i}.mlp.fc3.weight")) + rename_keys.append((f"trunk.blocks.{i}.norm_2.weight", f"trunk.blocks.{i}.norm_2.weight")) + + rename_keys.append(("trunk.post_trunk_norm.weight", "trunk.post_trunk_norm.weight")) + + # fmt: on + return rename_keys + + +def rename_key(dct, old, new): + val = dct.pop(old) + dct[new] = val + + +# We will verify our results on an image of a dog +def prepare_img(): + url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png" + image = Image.open(requests.get(url, stream=True).raw) + return image + + +@torch.no_grad() +def convert_aimv2_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub=False): + """ + Copy/paste/tweak model's weights to our AIMv2 structure. + """ + model_name_to_repo_id = { + "aimv2-large-patch14-224": "apple/aimv2-large-patch14-224", + "aimv2-huge-patch14-224": "apple/aimv2-huge-patch14-224", + "aimv2-1B-patch14-224": "apple/aimv2-1B-patch14-224", + "aimv2-3B-patch14-224": "apple/aimv2-3B-patch14-224", + "aimv2-large-patch14-336": "apple/aimv2-large-patch14-336", + "aimv2-huge-patch14-336": "apple/aimv2-huge-patch14-336", + "aimv2-1B-patch14-336": "apple/aimv2-1B-patch14-336", + "aimv2-3B-patch14-336": "apple/aimv2-3B-patch14-336", + "aimv2-large-patch14-448": "apple/aimv2-large-patch14-448", + "aimv2-huge-patch14-448": "apple/aimv2-huge-patch14-448", + "aimv2-1B-patch14-448": "apple/aimv2-1B-patch14-448", + "aimv2-3B-patch14-448": "apple/aimv2-3B-patch14-448", + "aimv2-large-patch14-224-distilled": "apple/aimv2-large-patch14-224-distilled", + "aimv2-large-patch14-336-distilled": "apple/aimv2-large-patch14-336-distilled", + "aimv2-large-patch14-native": "apple/aimv2-large-patch14-native", + "aimv2-large-patch14-224-lit": "apple/aimv2-large-patch14-224-lit", + } + + # define default AIMv2 configuration + config = get_aimv2_config(model_name) + logger.info(f"Model config: {config}") + + # load original model from torch hub + repo_id = model_name_to_repo_id[model_name] + filename = "model.safetensors" + + filepath = hf_hub_download(repo_id=repo_id, filename=filename) + state_dict = {} + with safe_open(filepath, framework="pt", device="cpu") as f: + for key in f.keys(): + state_dict[key] = f.get_tensor(key) + + # load HuggingFace model + model = AIMv2Model(config).eval() + + # rename keys + rename_keys = create_rename_keys(config) + for src, dest in rename_keys: + rename_key(state_dict, src, dest) + + missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=True) + + # assert values + assert missing_keys == [], str(missing_keys) + assert unexpected_keys == [], str(unexpected_keys) + + # load image + image = prepare_img() + # preprocess image + preprocessor = AutoImageProcessor.from_pretrained("apple/aimv2-large-patch14-224") + inputs = preprocessor(images=image, return_tensors="pt") + + with torch.no_grad(): + model(**inputs) + + print("Looks ok!") + + if pytorch_dump_folder_path is not None: + Path(pytorch_dump_folder_path).mkdir(exist_ok=True) + print(f"Saving model {model_name} to {pytorch_dump_folder_path}") + model.save_pretrained(pytorch_dump_folder_path) + print(f"Saving preprocessor to {pytorch_dump_folder_path}") + preprocessor.save_pretrained(pytorch_dump_folder_path) + + if push_to_hub: + model.push_to_hub(f"apple/{model_name}") + preprocessor.push_to_hub(f"apple/{model_name}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument( + "--model_name", + default="aimv2-large-patch14-224", + type=str, + choices=[ + "aimv2-large-patch14-224", + "aimv2-huge-patch14-224", + "aimv2-1B-patch14-224", + "aimv2-3B-patch14-224", + "aimv2-large-patch14-336", + "aimv2-huge-patch14-336", + "aimv2-1B-patch14-336", + "aimv2-3B-patch14-336", + "aimv2-large-patch14-448", + "aimv2-huge-patch14-448", + "aimv2-1B-patch14-448", + "aimv2-3B-patch14-448", + "aimv2-large-patch14-224-distilled", + "aimv2-large-patch14-336-distilled", + "aimv2-large-patch14-native", + "aimv2-large-patch14-224-lit", + ], + help="Name of the AIMv2 model you'd like to convert.", + ) + parser.add_argument( + "--pytorch_dump_folder_path", + default=None, + type=str, + help="Path to the output PyTorch model directory.", + ) + parser.add_argument( + "--push_to_hub", + action="store_true", + help="Whether or not to push the converted model to the 🤗 hub.", + ) + + args = parser.parse_args() + convert_aimv2_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.push_to_hub) diff --git a/src/transformers/models/aimv2/modeling_aimv2.py b/src/transformers/models/aimv2/modeling_aimv2.py new file mode 100644 index 000000000000..abf502bc4578 --- /dev/null +++ b/src/transformers/models/aimv2/modeling_aimv2.py @@ -0,0 +1,354 @@ +# coding=utf-8 +# Copyright 2025 Apple and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch AIMv2 model.""" + +from typing import Optional, Tuple, Union + +import torch +from torch import nn +from torch.nn import functional as F + +from ...modeling_outputs import BaseModelOutputWithNoAttention +from ...modeling_utils import PreTrainedModel +from ...utils import add_start_docstrings +from .configuration_aimv2 import AIMv2Config + + +__all__ = ["AIMv2Model", "AIMv2PreTrainedModel"] + +_CONFIG_FOR_DOC = "AIMv2Config" + +AIMV2_START_DOCSTRING = r""" + + This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it + as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and + behavior. + + Parameters: + config ([`AIMv2Config`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +AIMV2_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See + [`~AutoImageProcessor.__call__`] for details. + + mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to apply to the attention scores. A value of 1 indicates the position is not masked, and a value of 0 + indicates the position is masked. + + + + What is the mask? Most models expect a value of 1, indicating the position *should* attend, and 0, + indicating the position *should not* attend. For example, if your input sequence length is 5 and you only + want to attend to the first 3 positions, the mask should be `[1, 1, 1, 0, 0]`. + + + + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +class AIMv2RMSNorm(nn.Module): + def __init__(self, hidden_size: int, eps: float = 1e-6): + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.eps = eps + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + output = self._norm(hidden_states.float()).type_as(hidden_states) + return output * self.weight + + def extra_repr(self) -> str: + return f"{tuple(self.weight.shape)}, eps={self.eps}" + + def _norm(self, hidden_states: torch.Tensor) -> torch.Tensor: + return hidden_states * torch.rsqrt(hidden_states.pow(2).mean(-1, keepdim=True) + self.eps) + + +class AIMv2SwiGLUFFN(nn.Module): + def __init__(self, config: AIMv2Config): + super().__init__() + hidden_features = config.intermediate_size + in_features = config.hidden_size + bias = config.use_bias + + self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) + self.fc2 = nn.Linear(hidden_features, in_features, bias=bias) + self.fc3 = nn.Linear(in_features, hidden_features, bias=bias) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = F.silu(self.fc1(hidden_states)) * self.fc3(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +class AIMv2PatchEmbed(nn.Module): + def __init__(self, config: AIMv2Config): + super().__init__() + self.proj = nn.Conv2d( + config.num_channels, + config.hidden_size, + kernel_size=(config.patch_size, config.patch_size), + stride=(config.patch_size, config.patch_size), + ) + self.norm = AIMv2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: + x = self.proj(pixel_values).flatten(2).transpose(1, 2) + x = self.norm(x) + return x + + +class AIMv2ViTPreprocessor(nn.Module): + def __init__(self, config: AIMv2Config): + super().__init__() + num_patches = (config.image_size // config.patch_size) ** 2 + + self.patchifier = AIMv2PatchEmbed(config) + self.pos_embed = nn.Parameter(torch.zeros((1, num_patches, config.hidden_size))) + + def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: + tokens = self.patchifier(pixel_values) + _, num_tokens, _ = tokens.shape + pos_embed = self.pos_embed.to(tokens.device) + tokens = tokens + pos_embed[:, :num_tokens] + return tokens + + +class AIMv2Attention(nn.Module): + def __init__(self, config: AIMv2Config): + super().__init__() + hidden_size = config.hidden_size + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(hidden_size / self.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.qkv = nn.Linear(hidden_size, self.all_head_size * 3, bias=config.qkv_bias) + self.attn_drop = nn.Dropout(config.attention_dropout) + self.proj = nn.Linear(self.all_head_size, hidden_size, bias=config.use_bias) + self.proj_drop = nn.Dropout(config.projection_dropout) + + def forward(self, hidden_states: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor: + batch_size, seq_length, _ = hidden_states.shape + qkv = ( + self.qkv(hidden_states) + .reshape(batch_size, seq_length, 3, self.num_attention_heads, self.attention_head_size) + .permute(2, 0, 3, 1, 4) + ) + query, key, value = qkv.unbind(0) + + context_layer = F.scaled_dot_product_attention(query, key, value, attn_mask=mask) + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(new_context_layer_shape) + + context_layer = self.proj(context_layer) + context_layer = self.proj_drop(context_layer) + + return context_layer + + +class AIMv2Block(nn.Module): + def __init__(self, config: AIMv2Config): + super().__init__() + self.attn = AIMv2Attention(config) + self.norm_1 = AIMv2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.mlp = AIMv2SwiGLUFFN(config) + self.norm_2 = AIMv2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward(self, hidden_states: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor: + hidden_states = hidden_states + self.attn(self.norm_1(hidden_states), mask) + hidden_states = hidden_states + self.mlp(self.norm_2(hidden_states)) + return hidden_states + + +class AIMv2Transformer(nn.Module): + def __init__(self, config: AIMv2Config): + super().__init__() + self.blocks = nn.ModuleList([AIMv2Block(config) for _ in range(config.num_hidden_layers)]) + self.post_trunk_norm = AIMv2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + tokens: torch.Tensor, + mask: Optional[torch.Tensor] = None, + output_hidden_states: bool = False, + ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, ...]]]: + hidden_states = () if output_hidden_states else None + for block in self.blocks: + if output_hidden_states: + hidden_states += (tokens,) + + tokens = block(tokens, mask) + + if output_hidden_states: + hidden_states += (tokens,) + + tokens = self.post_trunk_norm(tokens) + return tokens, hidden_states + + +@add_start_docstrings( + "The bare AIMv2 Model transformer outputting raw hidden-states without any specific head on top.", + AIMV2_START_DOCSTRING, +) +class AIMv2PreTrainedModel(PreTrainedModel): + config_class = AIMv2Config + base_model_prefix = "aimv2" + main_input_name = "pixel_values" + _no_split_modules = ["AIMv2ViTPreprocessor", "AIMv2Block"] + _supports_sdpa = True + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Conv2d): + # Use Kaiming initialization for convolutional layers + nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu") + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, AIMv2ViTPreprocessor): + module.pos_embed.data = nn.init.trunc_normal_( + module.pos_embed.data.to(torch.float32), + mean=0.0, + std=self.config.initializer_range, + ).to(module.pos_embed.dtype) + + +@add_start_docstrings( + "The bare AIMv2 Model transformer outputting raw hidden-states without any specific head on top.", + AIMV2_START_DOCSTRING, +) +class AIMv2Model(AIMv2PreTrainedModel): + def __init__(self, config: AIMv2Config): + super().__init__(config) + self.preprocessor = AIMv2ViTPreprocessor(config) + self.trunk = AIMv2Transformer(config) + + def _init_weights(self, module): + if isinstance(module, (nn.Linear, nn.Conv2d)): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, AIMv2RMSNorm): + module.weight.data.fill_(1.0) + elif isinstance(module, AIMv2ViTPreprocessor): + module.pos_embed.data = nn.init.trunc_normal_( + module.pos_embed.data.to(torch.float32), + mean=0.0, + std=self.config.initializer_range, + ).to(module.pos_embed.dtype) + + def get_input_embeddings(self) -> AIMv2PatchEmbed: + return self.preprocessor.patchifier + + def forward( + self, + pixel_values: torch.Tensor, + mask: Optional[torch.Tensor] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + output_attentions: Optional[bool] = None, + ) -> Union[ + Tuple[torch.Tensor], + Tuple[torch.Tensor, Tuple[torch.Tensor, ...]], + BaseModelOutputWithNoAttention, + ]: + """ + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See + [`~AutoImageProcessor.__call__`] for details. + + mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to apply to the attention scores. A value of 1 indicates the position is not masked, and a value of 0 + indicates the position is masked. + + + + What is the mask? Most models expect a value of 1, indicating the position *should* attend, and 0, + indicating the position *should not* attend. For example, if your input sequence length is 5 and you only + want to attend to the first 3 positions, the mask should be `[1, 1, 1, 0, 0]`. + + + + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + + Returns: + Returns a tuple if not dictionary if config.use_return_dict is set to True, else a tuple. + x (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Hidden states of the output at the output of the last layer of the model. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned if `output_hidden_states=True` is passed or if `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of + the model at the output of each layer plus the optional initial embedding outputs. + + Returns: + + Examples: + ```python + >>> from transformers import AutoImageProcessor, AIMv2Model + >>> from PIL import Image + >>> import requests + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> processor = AutoImageProcessor.from_pretrained("apple/aimv2-large-patch14-224") + >>> model = AIMv2Model.from_pretrained("apple/aimv2-large-patch14-224") + + >>> inputs = processor(images=image, return_tensors="pt") + >>> outputs = model(**inputs) + >>> last_hidden_state = outputs.last_hidden_state + ``` + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + x = self.preprocessor(pixel_values) + x, hidden_states = self.trunk(x, mask, output_hidden_states=output_hidden_states) + + if not return_dict: + res = (x,) + res += (hidden_states,) if output_hidden_states else () + return res + + return BaseModelOutputWithNoAttention( + last_hidden_state=x, + hidden_states=hidden_states if output_hidden_states else None, + ) diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 69ce8efa10c7..431e7317a46e 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -32,6 +32,7 @@ CONFIG_MAPPING_NAMES = OrderedDict( [ # Add configs here + ("aimv2", "AIMv2Config"), ("albert", "AlbertConfig"), ("align", "AlignConfig"), ("altclip", "AltCLIPConfig"), @@ -332,6 +333,7 @@ MODEL_NAMES_MAPPING = OrderedDict( [ # Add full (and cased) model names here + ("aimv2", "AIMv2"), ("albert", "ALBERT"), ("align", "ALIGN"), ("altclip", "AltCLIP"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index e8a2dece4324..c25158abe23a 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -32,6 +32,7 @@ MODEL_MAPPING_NAMES = OrderedDict( [ # Base model mapping + ("aimv2", "AIMv2Model"), ("albert", "AlbertModel"), ("align", "AlignModel"), ("altclip", "AltCLIPModel"), @@ -572,6 +573,7 @@ MODEL_FOR_IMAGE_MAPPING_NAMES = OrderedDict( [ # Model for Image mapping + ("aimv2", "AIMv2Model"), ("beit", "BeitModel"), ("bit", "BitModel"), ("conditional_detr", "ConditionalDetrModel"), @@ -646,6 +648,7 @@ MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict( [ # Model for Image Classification mapping + ("aimv2", "AIMv2ForImageClassification"), ("beit", "BeitForImageClassification"), ("bit", "BitForImageClassification"), ("clip", "CLIPForImageClassification"), diff --git a/tests/models/aimv2/test_modeling_aimv2.py b/tests/models/aimv2/test_modeling_aimv2.py new file mode 100644 index 000000000000..bed0eb8f035b --- /dev/null +++ b/tests/models/aimv2/test_modeling_aimv2.py @@ -0,0 +1,293 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Testing suite for the PyTorch AIMv2 model.""" + +import unittest + +from transformers import AIMv2Config +from transformers.testing_utils import require_torch, require_vision, slow, torch_device +from transformers.utils import cached_property, is_torch_available, is_vision_available + +from ...test_configuration_common import ConfigTester +from ...test_modeling_common import ModelTesterMixin, floats_tensor +from ...test_pipeline_mixin import PipelineTesterMixin + + +if is_torch_available(): + import torch + import torch.nn as nn + + from transformers import AIMv2Model + +if is_vision_available(): + from PIL import Image + + from transformers import AutoImageProcessor + + +class AIMv2ModelTester: + def __init__( + self, + parent, + batch_size=13, + image_size=30, + patch_size=2, + num_channels=3, + is_training=True, + use_labels=True, + hidden_size=32, + num_hidden_layers=2, + num_attention_heads=4, + intermediate_size=37, + hidden_act="silu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + type_sequence_label_size=10, + initializer_range=0.02, + scope=None, + qkv_bias=True, + use_bias=False, + ): + self.parent = parent + self.batch_size = batch_size + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.is_training = is_training + self.use_labels = use_labels + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.type_sequence_label_size = type_sequence_label_size + self.initializer_range = initializer_range + self.scope = scope + self.qkv_bias = qkv_bias + self.use_bias = use_bias + + # in AIMv2, the seq length equals the number of patches + num_patches = (image_size // patch_size) ** 2 + self.seq_length = num_patches + + def prepare_config_and_inputs(self): + pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size]) + + config = self.get_config() + + return config, pixel_values + + def get_config(self): + return AIMv2Config( + image_size=self.image_size, + patch_size=self.patch_size, + num_channels=self.num_channels, + hidden_size=self.hidden_size, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads, + intermediate_size=self.intermediate_size, + hidden_act=self.hidden_act, + hidden_dropout_prob=self.hidden_dropout_prob, + attention_dropout=self.attention_probs_dropout_prob, + initializer_range=self.initializer_range, + qkv_bias=self.qkv_bias, + use_bias=self.use_bias, + ) + + def create_and_check_model(self, config, pixel_values): + model = AIMv2Model(config=config) + model.to(torch_device) + model.eval() + result = model(pixel_values) + self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + config, pixel_values = config_and_inputs + inputs_dict = {"pixel_values": pixel_values} + return config, inputs_dict + + +@require_torch +class AIMv2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): + """ + Here we also overwrite some of the tests of test_modeling_common.py, as AIMv2 does not use input_ids, inputs_embeds, + attention_mask and seq_length. + """ + + all_model_classes = (AIMv2Model,) if is_torch_available() else () + pipeline_model_mapping = {"image-feature-extraction": AIMv2Model} if is_torch_available() else {} + fx_compatible = True + + test_pruning = False + test_resize_embeddings = False + test_head_masking = False + has_attentions = False + + def setUp(self): + self.model_tester = AIMv2ModelTester(self) + self.config_tester = ConfigTester(self, config_class=AIMv2Config, has_text_modality=False, hidden_size=37) + + def test_config(self): + self.config_tester.run_common_tests() + + @unittest.skip(reason="AIMv2 does not use inputs_embeds") + def test_inputs_embeds(self): + pass + + @unittest.skip( + reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124" + ) + def test_training_gradient_checkpointing(self): + pass + + @unittest.skip( + reason="This architecture seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124" + ) + def test_training_gradient_checkpointing_use_reentrant(self): + pass + + @unittest.skip( + reason="This architecture seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124" + ) + def test_training_gradient_checkpointing_use_reentrant_false(self): + pass + + def test_model_get_set_embeddings(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + for model_class in self.all_model_classes: + model = model_class(config) + x = model.get_input_embeddings() + self.assertTrue(x is not None and isinstance(x, nn.Module)) + + def test_initialization(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config=config) + for name, param in model.named_parameters(): + if param.requires_grad: + # Check if mean is within a reasonable range around 0.0 + self.assertTrue( + -0.02 <= param.data.mean().item() <= 0.02, + msg=f"Parameter {name} of model {model_class} seems not properly initialized", + ) + + @unittest.skip("Test is designed for torch-fx tracing which is currently not supported") + def test_torch_fx(self): + pass + + def test_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(*config_and_inputs) + + @unittest.skip(reason="AIMv2 does not support feedforward chunking yet") + def test_feed_forward_chunking(self): + pass + + @slow + def test_model_from_pretrained(self): + model_name = "apple/aimv2-large-patch14-224" + model = AIMv2Model.from_pretrained(model_name) + self.assertIsNotNone(model) + + @unittest.skip(reason="AIMv2 does not output attentions") + def test_attention_outputs(self): + pass + + @unittest.skip("Test is designed for torch-fx tracing which is currently not supported") + def test_torch_fx_tracing(self): + pass + + @unittest.skip("Test is designed for torch-fx tracing which is currently not supported") + def test_torch_fx_output_loss(self): + pass + + def test_hidden_states_output(self): + def check_hidden_states_output(inputs_dict, config, model_class): + model = model_class(config) + model.to(torch_device) + model.eval() + + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + + hidden_states = outputs.hidden_states + expected_num_layers = getattr( + self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1 + ) + self.assertEqual(len(hidden_states), expected_num_layers) + + seq_length = self.model_tester.seq_length + + self.assertListEqual( + list(hidden_states[0].shape[-2:]), + [seq_length, self.model_tester.hidden_size], + ) + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + inputs_dict["output_hidden_states"] = True + check_hidden_states_output(inputs_dict, config, model_class) + + # check that output_hidden_states also work using config + del inputs_dict["output_hidden_states"] + config.output_hidden_states = True + + check_hidden_states_output(inputs_dict, config, model_class) + + +# We will verify our results on an image of a dog +def prepare_img(): + image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png") + return image + + +@require_torch +@require_vision +class AIMv2ModelIntegrationTest(unittest.TestCase): + @cached_property + def default_image_processor(self): + return AutoImageProcessor.from_pretrained("apple/aimv2-large-patch14-224") if is_vision_available() else None + + @slow + def test_inference_no_head(self): + model = AIMv2Model.from_pretrained("apple/aimv2-large-patch14-224").to(torch_device) + + image_processor = self.default_image_processor + image = prepare_img() + inputs = image_processor(image, return_tensors="pt").to(torch_device) + + # forward pass + with torch.no_grad(): + outputs = model(**inputs) + + # verify the last hidden states + expected_shape = torch.Size((1, 256, 1024)) + self.assertEqual(outputs.last_hidden_state.shape, expected_shape) + + expected_slice = torch.tensor( + [ + [0.0509, 0.0806, -0.0989], + [2.7847, -2.5148, -0.3327], + [2.8176, -2.4086, -0.2774], + ], # Replace this with the actual output from your model + device=torch_device, + ) + + self.assertTrue(torch.allclose(outputs.last_hidden_state[0, :3, :3], expected_slice, atol=1e-4)) From 94352ced971604aebe01a93c62238ca08fba3d91 Mon Sep 17 00:00:00 2001 From: AlanPonnachan Date: Tue, 7 Jan 2025 23:40:38 +0530 Subject: [PATCH 2/4] further improvements --- src/transformers/__init__.py | 8 ++ .../models/aimv2/modeling_aimv2.py | 92 +++++++++---------- 2 files changed, 54 insertions(+), 46 deletions(-) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 7966c1feab9d..8767b2869365 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -158,6 +158,7 @@ ], # Models "models": [], + "models.aimv2": ["AIMv2Config"], "models.albert": ["AlbertConfig"], "models.align": [ "AlignConfig", @@ -1404,6 +1405,13 @@ # PyTorch models structure + _import_structure["models.aimv2"].extend( + [ + "AIMv2Model", + "AIMv2PreTrainedModel", + ] + ) + _import_structure["models.albert"].extend( [ "AlbertForMaskedLM", diff --git a/src/transformers/models/aimv2/modeling_aimv2.py b/src/transformers/models/aimv2/modeling_aimv2.py index abf502bc4578..26cff0439d54 100644 --- a/src/transformers/models/aimv2/modeling_aimv2.py +++ b/src/transformers/models/aimv2/modeling_aimv2.py @@ -280,59 +280,59 @@ def forward( BaseModelOutputWithNoAttention, ]: """ - Args: - pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): - Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See - [`~AutoImageProcessor.__call__`] for details. + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See + [`~AutoImageProcessor.__call__`] for details. - mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to apply to the attention scores. A value of 1 indicates the position is not masked, and a value of 0 - indicates the position is masked. + mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to apply to the attention scores. A value of 1 indicates the position is not masked, and a value of 0 + indicates the position is masked. - + - What is the mask? Most models expect a value of 1, indicating the position *should* attend, and 0, - indicating the position *should not* attend. For example, if your input sequence length is 5 and you only - want to attend to the first 3 positions, the mask should be `[1, 1, 1, 0, 0]`. + What is the mask? Most models expect a value of 1, indicating the position *should* attend, and 0, + indicating the position *should not* attend. For example, if your input sequence length is 5 and you only + want to attend to the first 3 positions, the mask should be `[1, 1, 1, 0, 0]`. - + - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - - Returns: - Returns a tuple if not dictionary if config.use_return_dict is set to True, else a tuple. - x (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): - Hidden states of the output at the output of the last layer of the model. - hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned if `output_hidden_states=True` is passed or if `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of - the model at the output of each layer plus the optional initial embedding outputs. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. Returns: - - Examples: - ```python - >>> from transformers import AutoImageProcessor, AIMv2Model - >>> from PIL import Image - >>> import requests - - >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" - >>> image = Image.open(requests.get(url, stream=True).raw) - - >>> processor = AutoImageProcessor.from_pretrained("apple/aimv2-large-patch14-224") - >>> model = AIMv2Model.from_pretrained("apple/aimv2-large-patch14-224") - - >>> inputs = processor(images=image, return_tensors="pt") - >>> outputs = model(**inputs) - >>> last_hidden_state = outputs.last_hidden_state - ``` + Returns a tuple if not dictionary if config.use_return_dict is set to True, else a tuple. + x (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Hidden states of the output at the output of the last layer of the model. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned if `output_hidden_states=True` is passed or if `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of + the model at the output of each layer plus the optional initial embedding outputs. + + Returns: + + Examples: + ```python + >>> from transformers import AutoImageProcessor, AIMv2Model + >>> from PIL import Image + >>> import requests + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> processor = AutoImageProcessor.from_pretrained("apple/aimv2-large-patch14-224") + >>> model = AIMv2Model.from_pretrained("apple/aimv2-large-patch14-224") + + >>> inputs = processor(images=image, return_tensors="pt") + >>> outputs = model(**inputs) + >>> last_hidden_state = outputs.last_hidden_state + ``` """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( From 14cea1d469c22558bab5b542a2b3941fda2b58fa Mon Sep 17 00:00:00 2001 From: AlanPonnachan Date: Tue, 7 Jan 2025 23:49:05 +0530 Subject: [PATCH 3/4] add init to tests --- tests/models/aimv2/__init__.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 tests/models/aimv2/__init__.py diff --git a/tests/models/aimv2/__init__.py b/tests/models/aimv2/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 From d8ab0e8eff2a630648a2c86744399fa887762fd4 Mon Sep 17 00:00:00 2001 From: AlanPonnachan Date: Sat, 25 Jan 2025 13:36:00 +0530 Subject: [PATCH 4/4] modified init --- docs/source/en/model_doc/aimv2.md | 2 +- src/transformers/models/aimv2/__init__.py | 41 ++++------------------- 2 files changed, 8 insertions(+), 35 deletions(-) diff --git a/docs/source/en/model_doc/aimv2.md b/docs/source/en/model_doc/aimv2.md index fc6c084db4da..3650fd9d14e1 100644 --- a/docs/source/en/model_doc/aimv2.md +++ b/docs/source/en/model_doc/aimv2.md @@ -1,4 +1,4 @@ -