diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 016d7279353d..f1b930db7ee0 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -649,6 +649,8 @@ title: GLPN - local: model_doc/hiera title: Hiera + - local: model_doc/imagebind + title: ImageBind - local: model_doc/imagegpt title: ImageGPT - local: model_doc/levit diff --git a/docs/source/en/index.md b/docs/source/en/index.md index bdea11a2456f..757278102d2c 100644 --- a/docs/source/en/index.md +++ b/docs/source/en/index.md @@ -170,6 +170,7 @@ Flax), PyTorch, and/or TensorFlow. | [IDEFICS](model_doc/idefics) | ✅ | ✅ | ❌ | | [Idefics2](model_doc/idefics2) | ✅ | ❌ | ❌ | | [Idefics3](model_doc/idefics3) | ✅ | ❌ | ❌ | +| [ImageBind](model_doc/imagebind) | ✅ | ❌ | ❌ | | [ImageGPT](model_doc/imagegpt) | ✅ | ❌ | ❌ | | [Informer](model_doc/informer) | ✅ | ❌ | ❌ | | [InstructBLIP](model_doc/instructblip) | ✅ | ❌ | ❌ | diff --git a/docs/source/en/model_doc/imagebind.md b/docs/source/en/model_doc/imagebind.md new file mode 100644 index 000000000000..0fbdaf72c927 --- /dev/null +++ b/docs/source/en/model_doc/imagebind.md @@ -0,0 +1,141 @@ + + +# ImageBind + +## Overview + +The ImageBind model was proposed in [ImageBind: One Embedding Space To Bind Them All](https://arxiv.org/abs/2305.05665) by Rohit Girdhar, Alaaeldin El-Nouby, Zhuang Liu, Mannat Singh, Kalyan Vasudev Alwala, Armand Joulin, Ishan Misra. +ImageBind is a multimodal joint embedding model for image/video, text, audio, depth, IMU, and thermal images. +For any input from these six modalities, it outputs the same-sized embedding that can be used for cross-modal and multimodal tasks. + +The abstract from the paper is the following: + +*We present ImageBind, an approach to learn a joint embedding across six different modalities - images, text, audio, depth, thermal, and IMU data. 
We show that all combinations of paired data are not necessary to train such a joint embedding, and only image-paired data is sufficient to bind the modalities together. ImageBind can leverage recent large scale vision-language models, and extends their zero-shot capabilities to new modalities just by using their natural pairing with images. It enables novel emergent applications 'out-of-the-box' including cross-modal retrieval, composing modalities with arithmetic, cross-modal detection and generation. The emergent capabilities improve with the strength of the image encoder and we set a new state-of-the-art on emergent zero-shot recognition tasks across modalities, outperforming specialist supervised models. Finally, we show strong few-shot recognition results outperforming prior work, and that ImageBind serves as a new way to evaluate vision models for visual and non-visual tasks.* + +This model was contributed by [EduardoPacheco](https://huggingface.co/EduardoPacheco) and [ruffy369](https://huggingface.co/ruffy369) and [dg845](https://huggingface.co/dg845) and [shehan97](https://huggingface.co/shehan97). +The original code can be found [here](https://github.com/facebookresearch/ImageBind). + +## Usage tips + +- ImageBind can be used for multi-modality similarity and zero-shot tasks. +- Currently only Vision (image and video), Audio and Text are supported. +- One can use [`ImageBindProcessor`] to prepare all or pairs of the available modalities. +- [`ImageBindModel`] `forward` expects only one pair of modalities, where one of them MUST be the vision modality. +- If interested only in the modality embeddings, one can use the [`ImageBindModel`] `get_xxx_features` methods or the appropriate `ImageBindXxxModelWithProjection` model. +- As the ImageBind vision and text encoders were frozen during training and are initialized from OpenCLIP ViT-H, an application using this model could add other modalities by including additional encoders.
+ +Here's one example of how to get the embeddings for images, text and audios (this example requires `torchaudio`!) + +```python +import torch +import torchaudio +from datasets import load_dataset +from transformers import ImageBindModel, ImageBindProcessor + +ds = load_dataset("EduardoPacheco/imagebind-example-data", split="train") +images = ds["image"] +text = ds["text"] +audios = ds["audio"] # It's a dict with keys -> array and sampling_rate +audios = [ + torchaudio.functional.resample( + torch.from_numpy(audio["array"]), + orig_freq=audio["sampling_rate"], + new_freq=16000 + ).numpy() + for audio in audios +] + +model = ImageBindModel.from_pretrained("EduardoPacheco/imagebind-huge") +processor = ImageBindProcessor.from_pretrained("EduardoPacheco/imagebind-huge") + +inputs = processor(text=text, images=images, audios=audios, padding=True, return_tensors="pt") + +with torch.no_grad(): + audio_embeds = model.get_audio_features(input_features=inputs.input_features) + image_embeds = model.get_image_features(pixel_values=inputs.pixel_values) + text_embeds = model.get_text_features(input_ids=inputs.input_ids, attention_mask=inputs.attention_mask) + +# we can compute probs to use for retrieval or zero-shot workflows. 
+probs_image_text = (image_embeds @ text_embeds.T).softmax(dim=-1) +probs_text_audio = (text_embeds @ audio_embeds.T).softmax(dim=-1) +probs_image_audio = (image_embeds @ audio_embeds.T).softmax(dim=-1) +``` + +## ImageBindConfig + +[[autodoc]] ImageBindConfig + - from_text_vision_configs + +## ImageBindTextConfig + +[[autodoc]] ImageBindTextConfig + +## ImageBindVisionConfig + +[[autodoc]] ImageBindVisionConfig + +## ImageBindAudioConfig + +[[autodoc]] ImageBindAudioConfig + +## ImageBindImageProcessor + +[[autodoc]] ImageBindImageProcessor + - preprocess + +## ImageBindFeatureExtractor + +[[autodoc]] ImageBindFeatureExtractor + +## ImageBindProcessor + +[[autodoc]] ImageBindProcessor + +## ImageBindModel + +[[autodoc]] ImageBindModel + - forward + - get_text_features + - get_image_features + - get_audio_features + +## ImageBindTextModel + +[[autodoc]] ImageBindTextModel + - forward + +## ImageBindTextModelWithProjection + +[[autodoc]] ImageBindTextModelWithProjection + - forward + +## ImageBindVisionModel + +[[autodoc]] ImageBindVisionModel + - forward + + +## ImageBindVisionModelWithProjection + +[[autodoc]] ImageBindVisionModelWithProjection + - forward + +## ImageBindAudioModel + +[[autodoc]] ImageBindAudioModel + - forward + +## ImageBindAudioModelWithProjection + +[[autodoc]] ImageBindAudioModelWithProjection + - forward \ No newline at end of file diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 236333fb1cbd..acdc3b11acfb 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -481,6 +481,14 @@ "models.idefics": ["IdeficsConfig"], "models.idefics2": ["Idefics2Config"], "models.idefics3": ["Idefics3Config"], + "models.imagebind": [ + "ImageBindAudioConfig", + "ImageBindConfig", + "ImageBindFeatureExtractor", + "ImageBindProcessor", + "ImageBindTextConfig", + "ImageBindVisionConfig", + ], "models.imagegpt": ["ImageGPTConfig"], "models.informer": ["InformerConfig"], "models.instructblip": [ @@ -1200,6 
+1208,7 @@ _import_structure["models.idefics"].extend(["IdeficsImageProcessor"]) _import_structure["models.idefics2"].extend(["Idefics2ImageProcessor"]) _import_structure["models.idefics3"].extend(["Idefics3ImageProcessor"]) + _import_structure["models.imagebind"].extend(["ImageBindImageProcessor"]) _import_structure["models.imagegpt"].extend(["ImageGPTFeatureExtractor", "ImageGPTImageProcessor"]) _import_structure["models.instructblipvideo"].extend(["InstructBlipVideoImageProcessor"]) _import_structure["models.layoutlmv2"].extend(["LayoutLMv2FeatureExtractor", "LayoutLMv2ImageProcessor"]) @@ -2439,6 +2448,18 @@ "Idefics3Processor", ] ) + _import_structure["models.imagebind"].extend( + [ + "ImageBindAudioModel", + "ImageBindAudioModelWithProjection", + "ImageBindModel", + "ImageBindPreTrainedModel", + "ImageBindTextModel", + "ImageBindTextModelWithProjection", + "ImageBindVisionModel", + "ImageBindVisionModelWithProjection", + ] + ) _import_structure["models.imagegpt"].extend( [ "ImageGPTForCausalImageModeling", @@ -5337,6 +5358,14 @@ ) from .models.idefics2 import Idefics2Config from .models.idefics3 import Idefics3Config + from .models.imagebind import ( + ImageBindAudioConfig, + ImageBindConfig, + ImageBindFeatureExtractor, + ImageBindProcessor, + ImageBindTextConfig, + ImageBindVisionConfig, + ) from .models.imagegpt import ImageGPTConfig from .models.informer import InformerConfig from .models.instructblip import ( @@ -6094,6 +6123,7 @@ from .models.idefics import IdeficsImageProcessor from .models.idefics2 import Idefics2ImageProcessor from .models.idefics3 import Idefics3ImageProcessor + from .models.imagebind import ImageBindImageProcessor from .models.imagegpt import ImageGPTFeatureExtractor, ImageGPTImageProcessor from .models.instructblipvideo import InstructBlipVideoImageProcessor from .models.layoutlmv2 import ( @@ -7136,6 +7166,16 @@ Idefics3PreTrainedModel, Idefics3Processor, ) + from .models.imagebind import ( + ImageBindAudioModel, + 
ImageBindAudioModelWithProjection, + ImageBindModel, + ImageBindPreTrainedModel, + ImageBindTextModel, + ImageBindTextModelWithProjection, + ImageBindVisionModel, + ImageBindVisionModelWithProjection, + ) from .models.imagegpt import ( ImageGPTForCausalImageModeling, ImageGPTForImageClassification, diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index 069c7f90564f..8e9450ac0003 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -116,6 +116,7 @@ idefics, idefics2, idefics3, + imagebind, imagegpt, informer, instructblip, diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 05d6e717be23..a5b9fc8ef3b0 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -134,6 +134,7 @@ ("idefics", "IdeficsConfig"), ("idefics2", "Idefics2Config"), ("idefics3", "Idefics3Config"), + ("imagebind", "ImageBindConfig"), ("imagegpt", "ImageGPTConfig"), ("informer", "InformerConfig"), ("instructblip", "InstructBlipConfig"), @@ -437,6 +438,7 @@ ("idefics", "IDEFICS"), ("idefics2", "Idefics2"), ("idefics3", "Idefics3"), + ("imagebind", "ImageBind"), ("imagegpt", "ImageGPT"), ("informer", "Informer"), ("instructblip", "InstructBLIP"), diff --git a/src/transformers/models/auto/feature_extraction_auto.py b/src/transformers/models/auto/feature_extraction_auto.py index 0ddab5681f2e..389b8a3b34a5 100644 --- a/src/transformers/models/auto/feature_extraction_auto.py +++ b/src/transformers/models/auto/feature_extraction_auto.py @@ -63,6 +63,7 @@ ("glpn", "GLPNFeatureExtractor"), ("groupvit", "CLIPFeatureExtractor"), ("hubert", "Wav2Vec2FeatureExtractor"), + ("imagebind", "ImageBindFeatureExtractor"), ("imagegpt", "ImageGPTFeatureExtractor"), ("layoutlmv2", "LayoutLMv2FeatureExtractor"), ("layoutlmv3", "LayoutLMv3FeatureExtractor"), diff --git 
a/src/transformers/models/auto/image_processing_auto.py b/src/transformers/models/auto/image_processing_auto.py index ef40798484ef..2a2e856e52ee 100644 --- a/src/transformers/models/auto/image_processing_auto.py +++ b/src/transformers/models/auto/image_processing_auto.py @@ -90,6 +90,7 @@ ("idefics", ("IdeficsImageProcessor",)), ("idefics2", ("Idefics2ImageProcessor",)), ("idefics3", ("Idefics3ImageProcessor",)), + ("imagebind", ("ImageBindImageProcessor",)), ("imagegpt", ("ImageGPTImageProcessor",)), ("instructblip", ("BlipImageProcessor",)), ("instructblipvideo", ("InstructBlipVideoImageProcessor",)), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 5a98e761adc1..23526d428aa8 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -131,6 +131,7 @@ ("idefics", "IdeficsModel"), ("idefics2", "Idefics2Model"), ("idefics3", "Idefics3Model"), + ("imagebind", "ImageBindModel"), ("imagegpt", "ImageGPTModel"), ("informer", "InformerModel"), ("jamba", "JambaModel"), @@ -1328,6 +1329,7 @@ ("chinese_clip", "ChineseCLIPModel"), ("clip", "CLIPModel"), ("clipseg", "CLIPSegModel"), + ("imagebind", "ImageBindModel"), ("siglip", "SiglipModel"), ] ) diff --git a/src/transformers/models/auto/processing_auto.py b/src/transformers/models/auto/processing_auto.py index c1f23bc1cb3f..11dd2ce56131 100644 --- a/src/transformers/models/auto/processing_auto.py +++ b/src/transformers/models/auto/processing_auto.py @@ -66,6 +66,7 @@ ("idefics", "IdeficsProcessor"), ("idefics2", "Idefics2Processor"), ("idefics3", "Idefics3Processor"), + ("imagebind", "ImageBindProcessor"), ("instructblip", "InstructBlipProcessor"), ("instructblipvideo", "InstructBlipVideoProcessor"), ("kosmos-2", "Kosmos2Processor"), diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py index 63549202969a..b3eac8fed7d0 100644 --- 
a/src/transformers/models/auto/tokenization_auto.py +++ b/src/transformers/models/auto/tokenization_auto.py @@ -220,6 +220,13 @@ ("idefics", (None, "LlamaTokenizerFast" if is_tokenizers_available() else None)), ("idefics2", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)), ("idefics3", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)), + ( + "imagebind", + ( + "CLIPTokenizer", + "CLIPTokenizerFast" if is_tokenizers_available() else None, + ), + ), ("instructblip", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)), ("instructblipvideo", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)), ( diff --git a/src/transformers/models/imagebind/__init__.py b/src/transformers/models/imagebind/__init__.py new file mode 100644 index 000000000000..e45da3df704a --- /dev/null +++ b/src/transformers/models/imagebind/__init__.py @@ -0,0 +1,30 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_imagebind import * + from .feature_extraction_imagebind import * + from .image_processing_imagebind import * + from .modeling_imagebind import * + from .processing_imagebind import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/src/transformers/models/imagebind/configuration_imagebind.py b/src/transformers/models/imagebind/configuration_imagebind.py new file mode 100644 index 000000000000..bd2765b29c4c --- /dev/null +++ b/src/transformers/models/imagebind/configuration_imagebind.py @@ -0,0 +1,546 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""ImageBind model configuration""" + +import copy +import os +from typing import Any, Dict, Optional, Union + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class ImageBindTextConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`ImageBindTextModel`]. It is used to instantiate a ImageBind + text encoder according to the specified arguments, defining the model architecture. 
Instantiating a configuration + with the defaults will yield a similar configuration to that of the text encoder of the ImageBind + [facebook/imagebind-huge](https://huggingface.co/facebook/imagebind-huge) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + vocab_size (`int`, *optional*, defaults to 49408): + Vocabulary size of the ImageBind text model. Defines the number of different tokens that can be represented by + the `inputs_ids` passed when calling [`ImageBindModel`]. + hidden_size (`int`, *optional*, defaults to 1024): + Dimensionality of the encoder layers and the pooler layer. + mlp_ratio (`float`, *optional*, defaults to 4.0): + The ratio of the hidden size in the feedforward network to the hidden size in the encoder layers. + projection_dim (`int`, *optional*, defaults to 1024): + If the ImageBind text model has an output projection layer, the dimension to which that projection layer + maps to. + num_hidden_layers (`int`, *optional*, defaults to 24): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 16): + Number of attention heads for each attention layer in the Transformer encoder. + max_position_embeddings (`int`, *optional*, defaults to 77): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` are supported. + layer_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the layer normalization layers. 
+ add_kv_bias (`bool`, *optional*, defaults to `False`): + Whether to add an extra learnable bias token to the attention key and value sequences. This is based on the + `add_bias_kv` argument to [`torch.nn.MultiheadAttention`](https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html). + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + drop_path_rate (`float`, *optional*, defaults to 0.0): + The dropout probability for the DropPath (stochastic) regularization layers. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + initializer_factor (`float`, *optional*, defaults to 1.0): + A factor for initializing all weight matrices (should be kept to 1, used internally for initialization + testing). + logit_scale_init_value (`float`, *optional*, defaults to 14.2857): + The initial value of the `logit_scale` parameter for the text component. If `None`, the logits will not + be scaled. + learnable_logit_scale (`bool`, *optional*, defaults to `True`): + Whether the `logit_scale` is learnable or fixed. + pad_token_id (`int`, *optional*, defaults to 0): + Padding token id. + bos_token_id (`int`, *optional*, defaults to 49406): + Beginning of stream token id. + eos_token_id (`int`, *optional*, defaults to 49407): + End of stream token id. + intermediate_size (`int`, *optional*, defaults to 4096): + Abstract intermediate size for MLP class. Always equal to hidden_size * mlp_ratio. 
+ + Example: + + ```python + >>> from transformers import ImageBindTextConfig, ImageBindTextModel + + >>> # Initializing a ImageBindTextConfig with facebook/imagebind-huge style configuration + >>> configuration = ImageBindTextConfig() + + >>> # Initializing a ImageBindTextModel (with random weights) from the facebook/imagebind-huge style configuration + >>> model = ImageBindTextModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "imagebind_text_model" + + def __init__( + self, + vocab_size=49408, + hidden_size=1024, + mlp_ratio=4.0, + projection_dim=1024, + num_hidden_layers=24, + num_attention_heads=16, + max_position_embeddings=77, + hidden_act="gelu", + layer_norm_eps=1e-6, + add_kv_bias=False, + attention_dropout=0.0, + drop_path_rate=0.0, + initializer_range=0.02, + initializer_factor=1.0, + logit_scale_init_value=14.2857, + learnable_logit_scale=True, + pad_token_id=0, + bos_token_id=49406, + eos_token_id=49407, + **kwargs, + ): + super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) + + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.mlp_ratio = mlp_ratio + self.projection_dim = projection_dim + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.max_position_embeddings = max_position_embeddings + self.initializer_range = initializer_range + self.initializer_factor = initializer_factor + self.add_kv_bias = add_kv_bias + self.attention_dropout = attention_dropout + self.drop_path_rate = drop_path_rate + self.layer_norm_eps = layer_norm_eps + self.hidden_act = hidden_act + self.logit_scale_init_value = logit_scale_init_value + self.learnable_logit_scale = learnable_logit_scale + self.intermediate_size = int(hidden_size * mlp_ratio) + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig": + 
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) + + # get the text config dict if we are loading from ImageBindConfig + if config_dict.get("model_type") == "imagebind": + config_dict = config_dict["text_config"] + + if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type: + logger.warning( + f"You are using a model of type {config_dict['model_type']} to instantiate a model of type " + f"{cls.model_type}. This is not supported for all configurations of models and can yield errors." + ) + + return cls.from_dict(config_dict, **kwargs) + + +class ImageBindVisionConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`ImageBindVisionModel`]. It is used to instantiate a + ImageBind vision encoder according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the vision encoder of the ImageBind + [facebook/imagebind-huge](https://huggingface.co/facebook/imagebind-huge) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + hidden_size (`int`, *optional*, defaults to 1280): + Dimensionality of the encoder layers and the pooler layer. + mlp_ratio (`float`, *optional*, defaults to 4.0): + The ratio of the hidden size in the feedforward network to the hidden size in the encoder layers. + projection_dim (`int`, *optional*, defaults to 1024): + If the ImageBind vision model has an output projection layer, the dimension to which that projection layer + maps to. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer encoder. 
+ num_attention_heads (`int`, *optional*, defaults to 16): + Number of attention heads for each attention layer in the Transformer encoder. + num_channels (`int`, *optional*, defaults to 3): + The number of channels in the input images. + num_frames (`int`, *optional*, defaults to 2): + If using video (spatiotemporal) input, the number of video frames in the spatiotemporal data. + image_size (`int`, *optional*, defaults to 224): + The size (resolution) of each image. + patch_size (`int`, *optional*, defaults to 14): + The size (resolution) of each patch. + hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` are supported. + layer_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the layer normalization layers. + add_kv_bias (`bool`, *optional*, defaults to `False`): + Whether to add an extra learnable bias token to the attention key and value sequences. This is based on the + `add_bias_kv` argument to [`torch.nn.MultiheadAttention`](https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html). + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + drop_path_rate (`float`, *optional*, defaults to 0.0): + The dropout probability for the DropPath (stochastic) regularization layers. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + initializer_factor (`float`, *optional*, defaults to 1.0): + A factor for initializing all weight matrices (should be kept to 1, used internally for initialization + testing). + logit_scale_init_value (`float`, *optional*): + The initial value of the `logit_scale` parameter for the vision component. If `None`, the logits will not + be scaled. 
+ learnable_logit_scale (`bool`, *optional*, defaults to `False`): + Whether the `logit_scale` is learnable or fixed. + feature_size (`int`, *optional*, defaults to 224): + The size (resolution) of each image or feature (equal to image_size), for abstraction of image_size in modeling file. + intermediate_size (`int`, *optional*, defaults to 5120): + Abstract intermediate size for MLP class. Always equal to hidden_size * mlp_ratio. + + Example: + + ```python + >>> from transformers import ImageBindVisionConfig, ImageBindVisionModel + + >>> # Initializing a ImageBindVisionConfig with facebook/imagebind-huge style configuration + >>> configuration = ImageBindVisionConfig() + + >>> # Initializing a ImageBindVisionModel (with random weights) from the facebook/imagebind-huge style configuration + >>> model = ImageBindVisionModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "imagebind_vision_model" + + def __init__( + self, + hidden_size=1280, + mlp_ratio=4.0, + projection_dim=1024, + num_hidden_layers=32, + num_attention_heads=16, + num_channels=3, + num_frames=2, + image_size=224, + patch_size=14, + hidden_act="gelu", + layer_norm_eps=1e-6, + add_kv_bias=False, + attention_dropout=0.0, + drop_path_rate=0.0, + initializer_range=0.02, + initializer_factor=1.0, + logit_scale_init_value=None, + learnable_logit_scale=False, + **kwargs, + ): + super().__init__(**kwargs) + + self.hidden_size = hidden_size + self.mlp_ratio = mlp_ratio + self.projection_dim = projection_dim + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_channels = num_channels + self.num_frames = num_frames + self.patch_size = patch_size + self.image_size = image_size + self.initializer_range = initializer_range + self.initializer_factor = initializer_factor + self.add_kv_bias = add_kv_bias + self.attention_dropout = attention_dropout + self.drop_path_rate = drop_path_rate + 
self.layer_norm_eps = layer_norm_eps + self.hidden_act = hidden_act + self.logit_scale_init_value = logit_scale_init_value + self.learnable_logit_scale = learnable_logit_scale + self.feature_size = image_size + self.intermediate_size = int(hidden_size * mlp_ratio) + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig": + config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) + + # get the vision config dict if we are loading from ImageBindConfig + if config_dict.get("model_type") == "imagebind": + config_dict = config_dict["vision_config"] + + if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type: + logger.warning( + f"You are using a model of type {config_dict['model_type']} to instantiate a model of type " + f"{cls.model_type}. This is not supported for all configurations of models and can yield errors." + ) + + return cls.from_dict(config_dict, **kwargs) + + +class ImageBindAudioConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`ImageBindAudioModel`]. It is used to instantiate a + ImageBind audio encoder according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the audio encoder of the ImageBind + [facebook/imagebind-huge](https://huggingface.co/facebook/imagebind-huge) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + mlp_ratio (`float`, *optional*, defaults to 4.0): + The ratio of the hidden size in the feedforward network to the hidden size in the encoder layers. 
+ projection_dim (`int`, *optional*, defaults to 1024): + If the ImageBind audio model has an output projection layer, the dimension to which that projection layer + maps to. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + num_mel_bins (`int`, *optional*, defaults to 128): + The number of frequency bins in the log-mel spectrogram. + target_len (`int`, *optional*, defaults to 204): + The length of the target sequence. + num_channels (`int`, *optional*, defaults to 1): + The number of channels in the input audio data. + patch_size (`int`, *optional*, defaults to 16): + The kernel size of the patch embedding 2D convolution layer. + stride (`int`, *optional*, defaults to 10): + The stride of the patch embedding 2D convolution layer. + hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` are supported. + layer_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the layer normalization layers. + add_kv_bias (`bool`, *optional*, defaults to `True`): + Whether to add an extra learnable bias token to the attention key and value sequences. This is based on the + `add_bias_kv` argument to [`torch.nn.MultiheadAttention`](https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html). + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + drop_path_rate (`float`, *optional*, defaults to 0.1): + The dropout probability for the DropPath (stochastic) regularization layers. 
+ initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + initializer_factor (`float`, *optional*, defaults to 1.0): + A factor for initializing all weight matrices (should be kept to 1, used internally for initialization + testing). + logit_scale_init_value (`float`, *optional*, defaults to 20.0): + The initial value of the `logit_scale` parameter for the audio component. If `None`, the logits will not + be scaled. + learnable_logit_scale (`bool`, *optional*, defaults to `False`): + Whether the `logit_scale` is learnable or fixed. + feature_size (`Tuple[int, int]`, *optional*, defaults to (128, 204)): + The size (resolution) of audio feature (equal to (num_mel_bins, target_len)), for abstraction of image_size in modeling file. + intermediate_size (`int`, *optional*, defaults to 3072): + Abstract intermediate size for MLP class. Always equal to hidden_size * mlp_ratio. + + Example: + ```python + >>> from transformers import ImageBindAudioConfig, ImageBindAudioModel + + >>> # Initializing a ImageBindAudioConfig with facebook/imagebind-huge style configuration + >>> configuration = ImageBindAudioConfig() + + >>> # Initializing a ImageBindAudioModel (with random weights) from the facebook/imagebind-huge style configuration + >>> model = ImageBindAudioModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + def __init__( + self, + hidden_size=768, + mlp_ratio=4.0, + projection_dim=1024, + num_hidden_layers=12, + num_attention_heads=12, + num_mel_bins=128, + target_len=204, + num_channels=1, + patch_size=16, + stride=10, + hidden_act="gelu", + layer_norm_eps=1e-6, + add_kv_bias=True, + attention_dropout=0.0, + drop_path_rate=0.1, + initializer_range=0.02, + initializer_factor=1.0, + logit_scale_init_value=20.0, + learnable_logit_scale=False, + **kwargs, + ): + super().__init__(**kwargs) + + self.hidden_size = 
hidden_size + self.mlp_ratio = mlp_ratio + self.projection_dim = projection_dim + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_mel_bins = num_mel_bins + self.target_len = target_len + self.num_channels = num_channels + self.patch_size = patch_size + self.stride = stride + self.initializer_range = initializer_range + self.initializer_factor = initializer_factor + self.add_kv_bias = add_kv_bias + self.attention_dropout = attention_dropout + self.drop_path_rate = drop_path_rate + self.layer_norm_eps = layer_norm_eps + self.hidden_act = hidden_act + self.logit_scale_init_value = logit_scale_init_value + self.learnable_logit_scale = learnable_logit_scale + self.feature_size = (num_mel_bins, target_len) + self.intermediate_size = int(hidden_size * mlp_ratio) + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig": + config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) + + # get the audio config dict if we are loading from ImageBindConfig + if config_dict.get("model_type") == "imagebind": + config_dict = config_dict["audio_config"] + + if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type: + logger.warning( + f"You are using a model of type {config_dict['model_type']} to instantiate a model of type " + f"{cls.model_type}. This is not supported for all configurations of models and can yield errors." + ) + + return cls.from_dict(config_dict, **kwargs) + + +class ImageBindConfig(PretrainedConfig): + r""" + [`ImageBindConfig`] is the configuration class to store the configuration of a [`ImageBindModel`]. It is used to instantiate + a ImageBind model according to the specified arguments, defining the text model and vision model configs. 
Instantiating
+    a configuration with the defaults will yield a similar configuration to that of the ImageBind
+    [facebook/imagebind-huge](https://huggingface.co/facebook/imagebind-huge) architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    Args:
+        text_config (`dict` or `ImageBindTextConfig`, *optional*):
+            Dictionary or an instance of `ImageBindTextConfig` that defines the text model configuration.
+        vision_config (`dict` or `ImageBindVisionConfig`, *optional*):
+            Dictionary or an instance of `ImageBindVisionConfig` that defines the vision model configuration.
+        audio_config (`dict` or `ImageBindAudioConfig`, *optional*):
+            Dictionary or an instance of `ImageBindAudioConfig` that defines the audio model configuration.
+        projection_dim (`int`, *optional*, defaults to 1024):
+            Dimensionality of text and vision projection layers.
+        kwargs (*optional*):
+            Dictionary of keyword arguments.
+ + Example: + + ```python + >>> from transformers import ImageBindConfig, ImageBindModel + + >>> # Initializing a ImageBindConfig with facebook/imagebind-huge style configuration + >>> configuration = ImageBindConfig() + + >>> # Initializing a ImageBindModel (with random weights) from the facebook/imagebind-huge style configuration + >>> model = ImageBindModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + + >>> # We can also initialize a ImageBindConfig from a ImageBindTextConfig and a ImageBindVisionConfig + >>> from transformers import ImageBindTextConfig, ImageBindVisionConfig + + >>> # Initializing a ImageBindText and ImageBindVision configuration + >>> config_text = ImageBindTextConfig() + >>> config_vision = ImageBindVisionConfig() + + >>> config = ImageBindConfig.from_text_vision_configs(config_text, config_vision) + ```""" + + model_type = "imagebind" + is_composition = True + + def __init__( + self, + text_config: Optional[Union[Dict[str, Any], ImageBindTextConfig]] = None, + vision_config: Optional[Union[Dict[str, Any], ImageBindVisionConfig]] = None, + audio_config: Optional[Union[Dict[str, Any], ImageBindAudioConfig]] = None, + projection_dim: int = 1024, + **kwargs, + ): + super().__init__(**kwargs) + + if text_config is None: + text_config = {} + logger.info("`text_config` is `None`. Initializing the `ImageBindTextConfig` with default values.") + + if vision_config is None: + vision_config = {} + logger.info("`vision_config` is `None`. initializing the `ImageBindVisionConfig` with default values.") + + if audio_config is None: + audio_config = {} + logger.info("`audio_config` is `None`. 
initializing the `ImageBindAudioConfig` with default values.") + + self.text_config = ImageBindTextConfig(**text_config) if isinstance(text_config, dict) else text_config + self.vision_config = ( + ImageBindVisionConfig(**vision_config) if isinstance(vision_config, dict) else vision_config + ) + self.audio_config = ImageBindAudioConfig(**audio_config) if isinstance(audio_config, dict) else audio_config + + self.projection_dim = projection_dim + self.initializer_factor = 1.0 + + @classmethod + # Copied from transformers.models.clip.configuration_clip.CLIPConfig.from_text_vision_configs with CLIP->ImageBind, clip->imagebind + def from_text_vision_configs( + cls, text_config: ImageBindTextConfig, vision_config: ImageBindVisionConfig, **kwargs + ): + r""" + Instantiate a [`ImageBindConfig`] (or a derived class) from imagebind text model configuration and imagebind vision model + configuration. + + Returns: + [`ImageBindConfig`]: An instance of a configuration object + """ + + return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs) + + def to_dict(self): + """ + Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`]. 
+ + Returns: + `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance, + """ + output = copy.deepcopy(self.__dict__) + output["text_config"] = self.text_config.to_dict() + output["vision_config"] = self.vision_config.to_dict() + output["audio_config"] = self.audio_config.to_dict() + output["model_type"] = self.__class__.model_type + return output + + +__all__ = ["ImageBindTextConfig", "ImageBindVisionConfig", "ImageBindAudioConfig", "ImageBindConfig"] diff --git a/src/transformers/models/imagebind/convert_imagebind_to_hf.py b/src/transformers/models/imagebind/convert_imagebind_to_hf.py new file mode 100644 index 000000000000..1a2a7e056def --- /dev/null +++ b/src/transformers/models/imagebind/convert_imagebind_to_hf.py @@ -0,0 +1,210 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import argparse + +import regex as re +import torch + +from transformers import ( + CLIPTokenizer, + ImageBindConfig, + ImageBindFeatureExtractor, + ImageBindImageProcessor, + ImageBindModel, + ImageBindProcessor, +) +from transformers.utils import logging + + +logging.set_verbosity_info() +logger = logging.get_logger(__name__) + + +ORIGINAL_TO_CONVERTED_KEY_MAPPING = { + # Vision + r"modality_preprocessors\.vision\.cls_token": "vision_model.embeddings.cls_token", + r"modality_preprocessors\.vision\.rgbt_stem\.proj\.1\.weight": "vision_model.embeddings.patch_embedding.projection.weight", + r"modality_preprocessors\.vision\.pos_embedding_helper\.pos_embed": "vision_model.embeddings.position_embeddings", + r"modality_heads\.vision\.0\.weight": "vision_model.layernorm.weight", + r"modality_heads\.vision\.0\.bias": "vision_model.layernorm.bias", + r"modality_heads\.vision\.2\.weight": "vision_projection.weight", + r"modality_trunks\.vision\.pre_transformer_layer\.0\.weight": "vision_model.pre_layernorm.weight", + r"modality_trunks\.vision\.pre_transformer_layer\.0\.bias": "vision_model.pre_layernorm.bias", + # Text + r"modality_preprocessors\.text\.pos_embed": "text_model.embeddings.position_embedding.weight", + r"modality_preprocessors\.text\.token_embedding\.weight": "text_model.embeddings.token_embedding.weight", + r"modality_heads\.text\.proj\.0\.weight": "text_model.layernorm.weight", + r"modality_heads\.text\.proj\.0\.bias": "text_model.layernorm.bias", + r"modality_heads\.text\.proj\.1\.weight": "text_projection.weight", + r"modality_postprocessors\.text\.1\.log_logit_scale": "text_postprocessor.log_logit_scale", + # Audio + r"modality_preprocessors\.audio\.cls_token": "audio_model.embeddings.cls_token", + r"modality_preprocessors\.audio\.rgbt_stem\.proj\.weight": "audio_model.embeddings.patch_embedding.projection.weight", + r"modality_preprocessors\.audio\.rgbt_stem\.norm_layer\.weight": "audio_model.embeddings.patch_embedding.layernorm.weight", + 
r"modality_preprocessors\.audio\.rgbt_stem\.norm_layer\.bias": "audio_model.embeddings.patch_embedding.layernorm.bias", + r"modality_preprocessors\.audio\.pos_embedding_helper\.pos_embed": "audio_model.embeddings.position_embeddings", + r"modality_heads\.audio\.0\.weight": "audio_model.layernorm.weight", + r"modality_heads\.audio\.0\.bias": "audio_model.layernorm.bias", + r"modality_heads\.audio\.2\.weight": "audio_projection.weight", +} + + +def rename_encoder_layers(config, modality): + rename_keys = {} + # fmt: off + # Patterns for the keys + key_patterns = [ + (r"attn\.in_proj_weight", f"{modality}_model.encoder.layers.{{layer_idx}}.self_attn.qkv_proj.weight"), + (r"attn\.in_proj_bias", f"{modality}_model.encoder.layers.{{layer_idx}}.self_attn.qkv_proj.bias"), + (r"attn\.out_proj\.weight", f"{modality}_model.encoder.layers.{{layer_idx}}.self_attn.out_proj.weight"), + (r"attn\.out_proj\.bias", f"{modality}_model.encoder.layers.{{layer_idx}}.self_attn.out_proj.bias"), + (r"norm_1\.weight", f"{modality}_model.encoder.layers.{{layer_idx}}.layernorm_before.weight"), + (r"norm_1\.bias", f"{modality}_model.encoder.layers.{{layer_idx}}.layernorm_before.bias"), + (r"mlp\.fc1\.weight", f"{modality}_model.encoder.layers.{{layer_idx}}.mlp.fc1.weight"), + (r"mlp\.fc1\.bias", f"{modality}_model.encoder.layers.{{layer_idx}}.mlp.fc1.bias"), + (r"mlp\.fc2\.weight", f"{modality}_model.encoder.layers.{{layer_idx}}.mlp.fc2.weight"), + (r"mlp\.fc2\.bias", f"{modality}_model.encoder.layers.{{layer_idx}}.mlp.fc2.bias"), + (r"norm_2\.weight", f"{modality}_model.encoder.layers.{{layer_idx}}.layernorm_after.weight"), + (r"norm_2\.bias", f"{modality}_model.encoder.layers.{{layer_idx}}.layernorm_after.bias"), + ] + + for layer_idx in range(config.num_hidden_layers): + for old_pattern, new_pattern in key_patterns: + rename_keys[f"modality_trunks.{modality}.blocks.{layer_idx}.{old_pattern}"] = new_pattern.format(layer_idx=layer_idx) + + if config.add_kv_bias: + 
rename_keys[f"modality_trunks.{modality}.blocks.{layer_idx}.attn.bias_k"] = f"{modality}_model.encoder.layers.{layer_idx}.self_attn.k_bias" + rename_keys[f"modality_trunks.{modality}.blocks.{layer_idx}.attn.bias_v"] = f"{modality}_model.encoder.layers.{layer_idx}.self_attn.v_bias" + + # fmt: on + + return rename_keys + + +# here we list all keys to be renamed (original name on the left, our name on the right) +def create_rename_keys(config): + vision_config = config.vision_config + text_config = config.text_config + audio_config = config.audio_config + + rename_keys = {} + + # fmt: off + + rename_keys.update(ORIGINAL_TO_CONVERTED_KEY_MAPPING) + + rename_keys.update( + rename_encoder_layers(vision_config, "vision") + ) + + rename_keys.update( + rename_encoder_layers(text_config, "text") + ) + + rename_keys.update( + rename_encoder_layers(audio_config, "audio") + ) + # fmt: on + + return rename_keys + + +def rename_model_keys(dct, rename_keys): + renamed_dict = {} + + for key, value in dct.items(): + new_key = key + for pattern, new_pattern in rename_keys.items(): + new_key = re.sub(pattern, new_pattern, new_key) + renamed_dict[new_key] = value + + return renamed_dict + + +def reshape_text_position_embeddings(state_dict): + # Need to convert from (1, contexc_length, hidden_size) -> (context_length, hidden_size) + position_embeddings = state_dict["text_model.embeddings.position_embedding.weight"] + state_dict["text_model.embeddings.position_embedding.weight"] = position_embeddings.squeeze(0) + + return state_dict + + +@torch.no_grad() +def convert_imagebind_checkpoint(args): + model_name = args.model_name + pytorch_dump_folder_path = args.pytorch_dump_folder_path + push_to_hub = args.push_to_hub + hub_repo_path = args.hub_repo_path + + config = ImageBindConfig() + + # Load original checkpoint + checkpoint_url = "https://dl.fbaipublicfiles.com/imagebind/imagebind_huge.pth" + original_state_dict = torch.hub.load_state_dict_from_url(checkpoint_url, map_location="cpu") + 
+ # Rename keys + new_state_dict = original_state_dict.copy() + rename_keys = create_rename_keys(config) + + new_state_dict = rename_model_keys(new_state_dict, rename_keys) + + reshape_text_position_embeddings(new_state_dict) + + # Load HF model + model = ImageBindModel(config) + + model.eval() + missing_keys, unexpected_keys = model.load_state_dict(new_state_dict, strict=False) + print("Missing keys:", missing_keys) + print("") + print("Unexpected keys:", unexpected_keys) + + tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") + image_processor = ImageBindImageProcessor() + feature_extractor = ImageBindFeatureExtractor() + processor = ImageBindProcessor(image_processor, tokenizer, feature_extractor) + + if pytorch_dump_folder_path is not None: + print(f"Saving model and processor for {model_name} to {pytorch_dump_folder_path}") + model.save_pretrained(pytorch_dump_folder_path) + + if push_to_hub: + print(f"Pushing model and processor for {model_name} to hub at {hub_repo_path}") + model.push_to_hub(hub_repo_path) + processor.push_to_hub(hub_repo_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument( + "--model-name", + default="imagebind-huge", + type=str, + choices=["imagebind-huge"], + help="Name of the ImageBind model you'd like to convert.", + ) + parser.add_argument( + "--pytorch-dump-folder-path", default=None, type=str, help="Path to the output PyTorch model directory." + ) + parser.add_argument( + "--push-to-hub", action="store_true", help="Whether or not to push the converted model to the 🤗 hub." + ) + parser.add_argument( + "--hub-repo-path", default=None, type=str, help="Path of the repository to push the model on the 🤗 hub." 
+ ) + + args = parser.parse_args() + convert_imagebind_checkpoint(args) diff --git a/src/transformers/models/imagebind/feature_extraction_imagebind.py b/src/transformers/models/imagebind/feature_extraction_imagebind.py new file mode 100644 index 000000000000..8311ab7673c8 --- /dev/null +++ b/src/transformers/models/imagebind/feature_extraction_imagebind.py @@ -0,0 +1,417 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Feature extractor class for ImageBind.""" + +from fractions import Fraction +from typing import List, Optional, Tuple, Union + +import numpy as np + +from ...audio_utils import mel_filter_bank, spectrogram, window_function +from ...feature_extraction_sequence_utils import SequenceFeatureExtractor +from ...feature_extraction_utils import BatchFeature +from ...utils import TensorType, is_speech_available, is_torch_available, logging + + +if is_speech_available(): + import torchaudio.compliance.kaldi as ta_kaldi + +if is_torch_available(): + import torch + + +logger = logging.get_logger(__name__) + + +def valid_batched_clipped_audio(raw_speech): + """ + Determines whether raw mono-channel audio input (or any other 1D data) is batched and clipped. 
The following + conditions will be recognized as valid audio: + + - unbatched: `List[float]`, `np.ndarray` (`ndim=1`) + - batched: `List[List[float]]`, `List[np.ndarray]` (`ndim=1`), `np.ndarray` (`ndim=2`) + - batched and clipped: `List[List[List[float]]]`, `List[List[np.ndarray]]` (`ndim=1`), List[np.ndarray] (`ndim=2`), np.ndarray (`ndim=3`) + """ + if isinstance(raw_speech, np.ndarray): + return 1 <= raw_speech.ndim <= 3 + if isinstance(raw_speech, (list, tuple)): + first_elem = raw_speech[0] + if isinstance(first_elem, float): + return True + if isinstance(first_elem, np.ndarray): + return 1 <= first_elem.ndim <= 2 + if isinstance(first_elem, (list, tuple)): + second_elem = first_elem[0] + if isinstance(second_elem, (float, np.ndarray)): + return True + if isinstance(second_elem, (list, tuple)): + return isinstance(second_elem[0], float) + + return False + + +def convert_raw_speech_to_numpy_array(raw_speech): + """If not already in numpy array format, convert raw_speech to a numpy array.""" + if isinstance(raw_speech, (list, tuple)) and isinstance(raw_speech[0], float): + raw_speech = [[np.asarray(raw_speech, dtype=np.float32)]] + elif isinstance(raw_speech, (list, tuple)) and isinstance(raw_speech[0], (list, tuple)): + if isinstance(raw_speech[0][0], float): + # List[List[float]] + raw_speech = [[np.asarray(audio, dtype=np.float32)] for audio in raw_speech] + elif isinstance(raw_speech[0][0], (list, tuple)): + # List[List[List[float]]] + raw_speech = [[np.asarray(audio, dtype=np.float32) for audio in clip] for clip in raw_speech] + + return raw_speech + + +def batch_and_clip_ndarray(array, data_dim=1, dtype=np.float32): + """ + Turns a possibly nested list of np.ndarrays into a batched and clipped output of type `List[List[np.ndarray]]`. 
+ """ + if ( + isinstance(array, (list, tuple)) + and isinstance(array[0], (list, tuple)) + and isinstance(array[0][0], np.ndarray) + ): + if array[0][0].ndim == data_dim: + return [[base_array.astype(dtype=dtype) for base_array in clips] for clips in array] + else: + raise ValueError( + f"`For List[List[np.ndarray]]` inputs the internal `np.ndarray`s are expected to have dimension" + f" {data_dim} but got dimension {array[0][0].ndim}" + ) + elif isinstance(array, (list, tuple)) and isinstance(array[0], np.ndarray): + if array[0].ndim == data_dim + 1: + return [[np.asarray(base_array, dtype=dtype) for base_array in clips] for clips in array] + elif array[0].ndim == data_dim: + return [[base_array.astype(dtype=dtype)] for base_array in array] + else: + raise ValueError( + f"For `List[np.ndarray]` inputs the internal `np.ndarray`s are expected to have dimension" + f" {data_dim} or {data_dim + 1} but got dimension {array[0].ndim}" + ) + elif isinstance(array, np.ndarray): + array = array.astype(dtype=dtype) + if array.ndim == data_dim + 2: + return [list(clips) for clips in array] + elif array.ndim == data_dim + 1: + return [[clip] for clip in array] + elif array.ndim == data_dim: + return [[array]] + else: + raise ValueError( + f"`np.ndarray` inputs are expected to have dimension in" + f" `[{data_dim}, {data_dim + 1}, {data_dim + 2}]` but instead got {array.ndim}" + ) + else: + raise ValueError(f"Could not make batched and clipped audio from {array}") + + +# Adapted from https://github.com/facebookresearch/pytorchvideo/blob/main/pytorchvideo/data/clip_sampling.py#L346 +def uniform_chunk_sampling( + total_duration: float, chunk_duration: float, num_chunks: int +) -> List[Tuple[Fraction, Fraction]]: + """ + Uniformly sample `num_chunks` chunks of duration `chunk_duration` from an audio/video of total duration `total_duration`. + + Args: + total_duration (`float`): s + Total duration of the audio/video. + chunk_duration (`float`): + Duration of each chunk. 
+        num_chunks (`int`):
+            Number of chunks to sample.
+    """
+    chunk_duration_fraction = Fraction(chunk_duration)
+    max_possible_clip_start = Fraction(max(total_duration - chunk_duration, 0))
+    uniform_clip = Fraction(max_possible_clip_start / max(num_chunks - 1, 1))
+
+    result = []
+    for clip_index in range(num_chunks):
+        clip_start_sec = uniform_clip * clip_index
+        clip_end_sec = clip_start_sec + chunk_duration_fraction
+        result.append((clip_start_sec, clip_end_sec))
+
+    return result
+
+
+class ImageBindFeatureExtractor(SequenceFeatureExtractor):
+    r"""
+    Constructs an ImageBind feature extractor.
+
+    This feature extractor inherits from [`~feature_extraction_sequence_utils.SequenceFeatureExtractor`] which contains
+    most of the main methods. Users should refer to this superclass for more information regarding those methods.
+
+    This class extracts mel-filter bank features from raw speech using TorchAudio, pads/truncates them to a fixed
+    length and normalizes them using a mean and standard deviation.
+
+    Args:
+        feature_size (`int`, *optional*, defaults to 1):
+            The feature dimension of the extracted features.
+        sampling_rate (`int`, *optional*, defaults to 16000):
+            The sampling rate at which the audio files should be digitalized expressed in hertz (Hz).
+        num_mel_bins (`int`, *optional*, defaults to 128):
+            Number of Mel-frequency bins.
+        max_length (`int`, *optional*, defaults to 204):
+            Maximum length to which to pad/truncate the extracted features.
+        padding_value (`float`, *optional*, defaults to 0.0):
+            The value to pad with when applying the padding strategy defined by the `padding` argument to
+            [`ImageBindFeatureExtractor.__call__`].
+        do_normalize (`bool`, *optional*, defaults to `True`):
+            Whether or not to normalize the log-Mel features using `mean` and `std`.
+        mean (`float`, *optional*, defaults to -4.268):
+            The mean value used to normalize the log-Mel features. Uses the AudioSet mean by default.
+ std (`float`, *optional*, defaults to 9.138): + The standard deviation value used to normalize the log-Mel features. Uses the AudioSet standard deviation + by default. + do_chunk (`bool`, *optional*, defaults to `True`): + Whether or not to sample multiple chunks from the input audio. If `False`, the entire audio will be used. + chunk_duration (`float`, *optional*, defaults to 2.0): + The duration of each chunk in seconds. + num_chunks (`int`, *optional*, defaults to 3): + The number of chunks to sample from the input audio. + """ + + model_input_names = ["input_features"] + + def __init__( + self, + feature_size=1, + sampling_rate=16000, + num_mel_bins=128, + max_length=204, + padding_value=0.0, + do_normalize=True, + mean=-4.268, + std=9.138, + do_chunk=True, + chunk_duration=2.0, + num_chunks=3, + **kwargs, + ): + super().__init__(feature_size=feature_size, sampling_rate=sampling_rate, padding_value=padding_value, **kwargs) + self.num_mel_bins = num_mel_bins + self.max_length = max_length + self.do_normalize = do_normalize + self.mean = mean + self.std = std + self.do_chunk = do_chunk + self.chunk_duration = chunk_duration + self.num_chunks = num_chunks + + if not is_speech_available(): + mel_filters = mel_filter_bank( + num_frequency_bins=256, + num_mel_filters=self.num_mel_bins, + min_frequency=20, + max_frequency=sampling_rate // 2, + sampling_rate=sampling_rate, + norm=None, + mel_scale="kaldi", + triangularize_in_mel_space=True, + ) + + self.mel_filters = np.pad(mel_filters, ((0, 1), (0, 0))) + self.window = window_function(400, "hann", periodic=False) + + def _extract_fbank_features( + self, + waveform: np.ndarray, + max_length: int, + ) -> np.ndarray: + """ + Get mel-filter bank features using TorchAudio. Note that TorchAudio requires 16-bit signed integers as inputs + and hence the waveform should not be normalized before feature extraction. 
+ """ + # Mean center the waveform + waveform -= waveform.mean() + + if is_speech_available(): + waveform = torch.from_numpy(waveform).unsqueeze(0) + fbank = ta_kaldi.fbank( + waveform, + sample_frequency=self.sampling_rate, + window_type="hanning", + num_mel_bins=self.num_mel_bins, + ) + else: + if waveform.size > 0: + waveform = np.squeeze(waveform) + else: + # Handle the empty waveform case + raise ValueError("Empty waveform input") + + fbank = spectrogram( + waveform, + self.window, + frame_length=400, + hop_length=160, + fft_length=512, + power=2.0, + center=False, + preemphasis=0.97, + mel_filters=self.mel_filters, + log_mel="log", + mel_floor=1.192092955078125e-07, + remove_dc_offset=True, + ).T + + fbank = torch.from_numpy(fbank) + + # Convert to [mel_bins, num_frames] shape + fbank = fbank.transpose(0, 1) + # pad to max_length + n_frames = fbank.size(1) + difference = max_length - n_frames + + if abs(difference) / n_frames > 0.2: + logger.warning_once( + f"Large padding or truncation for {tuple(waveform.shape)} waveform with {n_frames} frames and {max_length} max_length." + ) + + # pad or truncate + if difference > 0: + fbank = torch.nn.functional.pad(fbank, (0, difference), mode="constant", value=0) + elif difference < 0: + fbank = fbank[:, 0:max_length] + # Add 1 channel so that dimension of fbank is [1, num_mel_bins, num_frames] + fbank = fbank.unsqueeze(0) + fbank = fbank.numpy() + + return fbank + + def normalize(self, input_values: np.ndarray, mean: float, std: float) -> np.ndarray: + return (input_values - (mean)) / (std) + + def chunk(self, raw_speech: np.ndarray, chunk_duration: float, num_chunks: int) -> List[np.ndarray]: + audio_duration = raw_speech.shape[0] / self.sampling_rate + if chunk_duration > audio_duration: + logger.warning_once( + "Chunk duration is greater than audio duration. Chunks will be repeated, consider adjusting either `chunk_duration` or `num_chunks`" + "to avoid unnecessary memory/compute usage." 
+ ) + all_clips_timepoints = uniform_chunk_sampling(audio_duration, chunk_duration, num_chunks) + + all_clips = [] + for clip_timepoints in all_clips_timepoints: + waveform_clip = raw_speech[ + int(clip_timepoints[0] * self.sampling_rate) : int(clip_timepoints[1] * self.sampling_rate) + ] + all_clips.append(waveform_clip) + + return all_clips + + def __call__( + self, + raw_speech: Union[np.ndarray, List[float], List[np.ndarray], List[List[float]], List[List[List[float]]]], + sampling_rate: Optional[int] = None, + do_normalize: Optional[bool] = None, + mean: Optional[float] = None, + std: Optional[float] = None, + do_chunk: Optional[bool] = None, + chunk_duration: Optional[float] = None, + num_chunks: Optional[int] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + **kwargs, + ) -> BatchFeature: + """ + Main method to featurize and prepare for the model one or several sequence(s). + + Args: + raw_speech (`np.ndarray`, `List[float]`, `List[np.ndarray]`, `List[List[float]]`, `List[List[List[float]]]`): + The sequence or batch of sequences to be padded. Each sequence can be a numpy array, a list of numpy + arrays or a (possibly nested) list of float values. The supported input types are as follows: + + - unbatched: `List[float]`, `np.ndarray` (`ndim=1`) + - batched: `List[List[float]]`, `List[np.ndarray]` (`ndim=1`), `np.ndarray` (`ndim=2`) + - batched with clips: `List[List[List[float]]]`, `List[List[np.ndarray]]` (`ndim=1`), `List[np.ndarray]` (`ndim=2`), np.ndarray (`ndim=3`) + + The input will always be interpreted as mono channel audio, not stereo, i.e. a single float per timestep. + sampling_rate (`int`, *optional*): + The sampling rate at which the `raw_speech` input was sampled. It is strongly recommended to pass + `sampling_rate` at the forward call to prevent silent errors. + do_normalize (`bool`, *optional*, defaults `self.do_normalize`): + Whether or not to normalize the log-Mel features. 
+ mean (`float`, *optional*, defaults `self.mean`): + The mean value used to normalize the log-Mel features. + std (`float`, *optional*, defaults `self.std`): + The standard deviation value used to normalize the log-Mel features. + do_chunk (`bool`, *optional*, defaults `self.do_chunk`): + Whether or not to sample multiple chunks from the input audio. If `False`, the entire audio will be used. + chunk_duration (`float`, *optional*, defaults `self.chunk_duration`): + The duration of each chunk in seconds. + num_chunks (`int`, *optional*, defaults `self.num_chunks`): + The number of chunks to sample from the input audio. If audio duration is less than `chunk_duration` * `num_chunks`, + chunks will overlap to cover the entire audio. If `chunk_duration` is greater than audio duration, the + chunks will be repeated until `num_chunks` is reached. + return_tensors (`str` or [`~utils.TensorType`], *optional*): + If set, will return tensors instead of list of python integers. Acceptable values are: + + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return Numpy `np.ndarray` objects. + """ + + if sampling_rate is not None: + if sampling_rate != self.sampling_rate: + raise ValueError( + f"The model corresponding to this feature extractor: {self} was trained using a sampling rate of" + f" {self.sampling_rate}. Please make sure that the provided `raw_speech` input was sampled with" + f" {self.sampling_rate} and not {sampling_rate}." + ) + else: + logger.warning( + "It is strongly recommended to pass the `sampling_rate` argument to this function. " + "Failing to do so can result in silent errors that might be hard to debug." 
+ ) + + if not valid_batched_clipped_audio(raw_speech): + raise ValueError( + f"Only unbatched, batched, and batched and clipped mono-channel audio is supported for input to {self}" + ) + + do_normalize = do_normalize if do_normalize is not None else self.do_normalize + mean = mean if mean is not None else self.mean + std = std if std is not None else self.std + do_chunk = do_chunk if do_chunk is not None else self.do_chunk + chunk_duration = chunk_duration if chunk_duration is not None else self.chunk_duration + num_chunks = num_chunks if num_chunks is not None else self.num_chunks + + raw_speech = convert_raw_speech_to_numpy_array(raw_speech) + raw_speech = batch_and_clip_ndarray(raw_speech, data_dim=1, dtype=np.float32) + + if do_chunk and len(raw_speech[0]) == 1: + raw_speech = [self.chunk(audio[0], chunk_duration, num_chunks) for audio in raw_speech] + + features = [ + [self._extract_fbank_features(waveform, max_length=self.max_length) for waveform in clip] + for clip in raw_speech + ] + + features = np.asarray(features) + padded_inputs = BatchFeature({"input_features": features}) + + if do_normalize: + padded_inputs["input_features"] = [ + [self.normalize(feature, mean, std) for feature in clip] for clip in padded_inputs["input_features"] + ] + + if return_tensors is not None: + padded_inputs = padded_inputs.convert_to_tensors(return_tensors) + + return padded_inputs + + +__all__ = ["ImageBindFeatureExtractor"] diff --git a/src/transformers/models/imagebind/image_processing_imagebind.py b/src/transformers/models/imagebind/image_processing_imagebind.py new file mode 100644 index 000000000000..5402256b6db1 --- /dev/null +++ b/src/transformers/models/imagebind/image_processing_imagebind.py @@ -0,0 +1,827 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Image processor class for ImageBind.""" + +import math +from fractions import Fraction +from typing import Dict, Iterable, List, Optional, Tuple, Union + +import numpy as np + +from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict +from ...image_transforms import ( + convert_to_rgb, + get_resize_output_image_size, + resize, + to_channel_dimension_format, + to_pil_image, +) +from ...image_utils import ( + OPENAI_CLIP_MEAN, + OPENAI_CLIP_STD, + ChannelDimension, + ImageInput, + PILImageResampling, + VideoInput, + get_image_size, + infer_channel_dimension_format, + is_scaled_image, + is_valid_image, + make_list_of_images, + to_numpy_array, + valid_images, + validate_kwargs, + validate_preprocess_arguments, +) +from ...utils import TensorType, is_torch_available, is_vision_available, logging, requires_backends +from ...utils.deprecation import deprecate_kwarg + + +logger = logging.get_logger(__name__) + + +if is_vision_available(): + import PIL + +if is_torch_available(): + import torch + + +# Copied from transformers.models.video_llava.image_processing_video_llava.make_batched_videos +def make_batched_videos(videos) -> List[VideoInput]: + if isinstance(videos, (list, tuple)) and isinstance(videos[0], (list, tuple)) and is_valid_image(videos[0][0]): + return videos + + elif isinstance(videos, (list, tuple)) and is_valid_image(videos[0]): + if isinstance(videos[0], PIL.Image.Image): + return [videos] + elif len(videos[0].shape) == 4: + return [list(video) for video in videos] + + elif is_valid_image(videos) and 
len(videos.shape) == 4:
+        return [list(videos)]
+
+    raise ValueError(f"Could not make batched video from {videos}")
+
+
+# Copied from transformers.models.imagebind.feature_extraction_imagebind.uniform_chunk_sampling
+def uniform_chunk_sampling(
+    total_duration: float, chunk_duration: float, num_chunks: int
+) -> List[Tuple[Fraction, Fraction]]:
+    """
+    Uniformly sample `num_chunks` chunks of duration `chunk_duration` from an audio/video of total duration `total_duration`.
+
+    Args:
+        total_duration (float): Total duration of the audio/video.
+        chunk_duration (float): Duration of each chunk.
+        num_chunks (int): Number of chunks to sample.
+
+    Returns:
+        List[Tuple[float, float]]: List of tuples where each tuple contains the start and end time of a chunk.
+    """
+    chunk_duration_fraction = Fraction(chunk_duration)
+    max_possible_clip_start = Fraction(max(total_duration - chunk_duration, 0))
+    uniform_clip = Fraction(max_possible_clip_start / max(num_chunks - 1, 1))
+
+    result = []
+    for clip_index in range(num_chunks):
+        clip_start_sec = uniform_clip * clip_index
+        clip_end_sec = clip_start_sec + chunk_duration_fraction
+        result.append((clip_start_sec, clip_end_sec))
+
+    return result
+
+
+# Adapted from https://github.com/facebookresearch/pytorchvideo/blob/a0a131e/pytorchvideo/transforms/functional.py#L19
+def uniform_temporal_subsample(video: VideoInput, num_samples: int) -> VideoInput:
+    """
+    Uniformly subsamples num_samples indices from the temporal dimension of the video.
+    When num_samples is larger than the size of temporal dimension of the video, it
+    will sample frames based on nearest neighbor interpolation.
+
+    Args:
+        video (`VideoInput`):
+            Video to subsample.
+        num_samples (`int`):
+            Number of frames to sample.
+    """
+    num_frames = len(video)
+
+    # Sample by nearest neighbor interpolation if num_samples > num_frames.
+    indices = np.linspace(0, num_frames - 1, num_samples)
+    indices = np.clip(indices, 0, num_frames - 1).astype(int)
+
+    return [video[i] for i in indices]
+
+
+# Adapted from https://github.com/facebookresearch/pytorchvideo/blob/1fadaef40dd393ca09680f55582399f4679fc9b7/pytorchvideo/transforms/functional.py#L92
+def video_resize(
+    frames: List[np.ndarray],
+    size: Tuple[int, int] = 224,
+    resampling: PILImageResampling = PILImageResampling.BILINEAR,
+    data_format: Optional[Union[str, ChannelDimension]] = None,
+    input_data_format: Optional[Union[str, ChannelDimension]] = None,
+) -> List[np.ndarray]:
+    """
+    Determines the shorter spatial dim of the video (i.e. width or height) and scales
+    it to the given size. To maintain aspect ratio, the longer side is then scaled
+    accordingly.
+    Args:
+        frames (List[np.ndarray]): The video frames to resize, one array per frame.
+        size (Tuple[int, int]): The target spatial size the frames are resized to.
+        resampling (PILImageResampling): Algorithm used for upsampling,
+            options: 'nearest' | 'linear' | 'bilinear' | 'bicubic' | 'trilinear' | 'area'
+        data_format (`str` or `ChannelDimension`, *optional*):
+            The channel dimension format of the image. If not provided, it will be the same as the input image.
+        input_data_format (`ChannelDimension` or `str`, *optional*):
+            The channel dimension format of the input image. If not provided, it will be inferred.
+    Returns:
+        A list of image-like numpy arrays with scaled spatial dims.
+ """ # noqa + requires_backends(video_resize, ["torch"]) + + # channel-first + frames = [ + to_channel_dimension_format(frame, ChannelDimension.FIRST, input_channel_dim=input_data_format) + for frame in frames + ] + # stack, to torch and reshape to num_channels, num_frames, height, width + video = np.stack(frames) + video = torch.from_numpy(video).contiguous() + + data_format = input_data_format if data_format is None else data_format + video = torch.nn.functional.interpolate(video, size=size, mode=resampling.name.lower(), align_corners=False) + frames = list(video.numpy()) + frames = [ + to_channel_dimension_format(frame, data_format, input_channel_dim=ChannelDimension.FIRST) for frame in frames + ] + + return frames + + +# Same as in image_transforms.py but taking offsets like int(math.ceil((orig_height - crop_height) / 2)) +@deprecate_kwarg("return_numpy", version="5.0") +def modified_center_crop( + image: np.ndarray, + size: Tuple[int, int], + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + return_numpy: Optional[bool] = None, +) -> np.ndarray: + """ + Crops the `image` to the specified `size` using a center crop. Note that if the image is too small to be cropped to + the size given, it will be padded (so the returned result will always be of size `size`). + + Args: + image (`np.ndarray`): + The image to crop. + size (`Tuple[int, int]`): + The target size for the cropped image. + data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format for the output image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + If unset, will use the inferred format of the input image. + input_data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format for the input image. 
Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + If unset, will use the inferred format of the input image. + return_numpy (`bool`, *optional*): + Whether or not to return the cropped image as a numpy array. Used for backwards compatibility with the + previous ImageFeatureExtractionMixin method. + - Unset: will return the same type as the input image. + - `True`: will return a numpy array. + - `False`: will return a `PIL.Image.Image` object. + Returns: + `np.ndarray`: The cropped image. + """ + requires_backends(modified_center_crop, ["vision"]) + + return_numpy = True if return_numpy is None else return_numpy + + if not isinstance(image, np.ndarray): + raise TypeError(f"Input image must be of type np.ndarray, got {type(image)}") + + if not isinstance(size, Iterable) or len(size) != 2: + raise ValueError("size must have 2 elements representing the height and width of the output image") + + if input_data_format is None: + input_data_format = infer_channel_dimension_format(image) + output_data_format = data_format if data_format is not None else input_data_format + + # We perform the crop in (C, H, W) format and then convert to the output format + image = to_channel_dimension_format(image, ChannelDimension.FIRST, input_data_format) + + orig_height, orig_width = get_image_size(image, ChannelDimension.FIRST) + crop_height, crop_width = size + crop_height, crop_width = int(crop_height), int(crop_width) + + # In case size is odd, (image_shape[0] + size[0]) // 2 won't give the proper result. + top = int(math.ceil((orig_height - crop_height) / 2)) + bottom = top + crop_height + # In case size is odd, (image_shape[1] + size[1]) // 2 won't give the proper result. 
+ left = int(math.ceil((orig_width - crop_width) / 2)) + right = left + crop_width + + # Check if cropped area is within image boundaries + if top >= 0 and bottom <= orig_height and left >= 0 and right <= orig_width: + image = image[..., top:bottom, left:right] + image = to_channel_dimension_format(image, output_data_format, ChannelDimension.FIRST) + return image + + # Otherwise, we may need to pad if the image is too small. Oh joy... + new_height = max(crop_height, orig_height) + new_width = max(crop_width, orig_width) + new_shape = image.shape[:-2] + (new_height, new_width) + new_image = np.zeros_like(image, shape=new_shape) + + # If the image is too small, pad it with zeros + top_pad = math.ceil((new_height - orig_height) / 2) + bottom_pad = top_pad + orig_height + left_pad = math.ceil((new_width - orig_width) / 2) + right_pad = left_pad + orig_width + new_image[..., top_pad:bottom_pad, left_pad:right_pad] = image + + top += top_pad + bottom += top_pad + left += left_pad + right += left_pad + + new_image = new_image[..., max(0, top) : min(new_height, bottom), max(0, left) : min(new_width, right)] + new_image = to_channel_dimension_format(new_image, output_data_format, ChannelDimension.FIRST) + + if not return_numpy: + new_image = to_pil_image(new_image) + + return new_image + + +class ImageBindImageProcessor(BaseImageProcessor): + r""" + Constructs an ImageBind image processor. + + Args: + do_resize (`bool`, *optional*, defaults to `True`): + Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by + `do_resize` in the `preprocess` method. + size (`Dict[str, int]` *optional*, defaults to `{"shortest_edge": 224}`): + Size of the image after resizing. The shortest edge of the image is resized to size["shortest_edge"], with + the longest edge resized to keep the input aspect ratio. Can be overridden by `size` in the `preprocess` + method. 
+        resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`):
+            Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method.
+        do_center_crop (`bool`, *optional*, defaults to `True`):
+            Whether to center crop the image to the specified `crop_size`. Can be overridden by `do_center_crop` in the
+            `preprocess` method.
+        crop_size (`Dict[str, int]`, *optional*, defaults to `{"height": 224, "width": 224}`):
+            Size of the output image after applying `center_crop`. Can be overridden by `crop_size` in the `preprocess`
+            method.
+        do_rescale (`bool`, *optional*, defaults to `True`):
+            Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in
+            the `preprocess` method.
+        rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
+            Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess`
+            method.
+        do_normalize (`bool`, *optional*, defaults to `True`):
+            Whether to normalize the image. Can be overridden by `do_normalize` in the `preprocess` method.
+        image_mean (`float` or `List[float]`, *optional*, defaults to `[0.48145466, 0.4578275, 0.40821073]`):
+            Mean to use if normalizing the image. This is a float or list of floats the length of the number of
+            channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
+        image_std (`float` or `List[float]`, *optional*, defaults to `[0.26862954, 0.26130258, 0.27577711]`):
+            Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
+            number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
+        do_convert_rgb (`bool`, *optional*, defaults to `True`):
+            Whether to convert the image to RGB.
+        do_chunk (`bool`, *optional*, defaults to `True`):
+            Whether to chunk the video into multiple clips.
+ chunk_duration (`float`, *optional*, defaults to 2.0): + Duration of each chunk in seconds. + num_chunks (`int`, *optional*, defaults to 5): + Number of chunks to sample. + num_frames_per_chunk (`int`, *optional*, defaults to 2): + Number of frames to sample per chunk. + fps (`int`, *optional*, defaults to 30): + Frame rate of the video. It's assumed that all videos have the same frame rate. + """ + + model_input_names = ["pixel_values"] + + def __init__( + self, + do_resize: bool = True, + size: Dict[str, int] = None, + resample: PILImageResampling = PILImageResampling.BICUBIC, + do_center_crop: bool = True, + crop_size: Dict[str, int] = None, + do_rescale: bool = True, + rescale_factor: Union[int, float] = 1 / 255, + do_normalize: bool = True, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + do_convert_rgb: bool = True, + do_chunk: bool = True, + chunk_duration: float = 2.0, + num_chunks: int = 5, + num_frames_per_chunk: int = 2, + fps: int = 30, + **kwargs, + ) -> None: + super().__init__(**kwargs) + size = size if size is not None else {"shortest_edge": 224} + size = get_size_dict(size, default_to_square=False) + crop_size = crop_size if crop_size is not None else {"height": 224, "width": 224} + crop_size = get_size_dict(crop_size, default_to_square=True, param_name="crop_size") + + self.do_resize = do_resize + self.size = size + self.resample = resample + self.do_center_crop = do_center_crop + self.crop_size = crop_size + self.do_rescale = do_rescale + self.rescale_factor = rescale_factor + self.do_normalize = do_normalize + self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN + self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD + self.do_convert_rgb = do_convert_rgb + self.do_chunk = do_chunk + self.chunk_duration = chunk_duration + self.num_chunks = num_chunks + self.num_frames_per_chunk = num_frames_per_chunk + self.fps = fps + 
self._valid_processor_keys = [ + "images", + "videos", + "do_resize", + "size", + "resample", + "do_center_crop", + "crop_size", + "do_rescale", + "rescale_factor", + "do_normalize", + "image_mean", + "image_std", + "do_convert_rgb", + "do_chunk", + "chunk_duration", + "num_chunks", + "num_frames_per_chunk", + "fps", + "return_tensors", + "data_format", + "input_data_format", + ] + + def video_resize( + self, + frames: List[np.ndarray], + size: Dict[str, int], + resampling: PILImageResampling = PILImageResampling.BILINEAR, + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ) -> List[np.ndarray]: + default_to_square = True + if "shortest_edge" in size: + size = size["shortest_edge"] + default_to_square = False + elif "height" in size and "width" in size: + size = (size["height"], size["width"]) + else: + raise ValueError("Size must contain either 'shortest_edge' or 'height' and 'width'.") + + output_size = get_resize_output_image_size( + frames[0], + size=size, + default_to_square=default_to_square, + input_data_format=input_data_format, + ) + + return video_resize( + frames=frames, + size=output_size, + resampling=resampling, + data_format=data_format, + input_data_format=input_data_format, + ) + + # Copied from transformers.models.clip.image_processing_clip.CLIPImageProcessor.resize + def resize( + self, + image: np.ndarray, + size: Dict[str, int], + resample: PILImageResampling = PILImageResampling.BICUBIC, + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, + ) -> np.ndarray: + """ + Resize an image. The shortest edge of the image is resized to size["shortest_edge"], with the longest edge + resized to keep the input aspect ratio. + + Args: + image (`np.ndarray`): + Image to resize. + size (`Dict[str, int]`): + Size of the output image. 
+            resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
+                Resampling filter to use when resiizing the image.
+            data_format (`str` or `ChannelDimension`, *optional*):
+                The channel dimension format of the image. If not provided, it will be the same as the input image.
+            input_data_format (`ChannelDimension` or `str`, *optional*):
+                The channel dimension format of the input image. If not provided, it will be inferred.
+        """
+        default_to_square = True
+        if "shortest_edge" in size:
+            size = size["shortest_edge"]
+            default_to_square = False
+        elif "height" in size and "width" in size:
+            size = (size["height"], size["width"])
+        else:
+            raise ValueError("Size must contain either 'shortest_edge' or 'height' and 'width'.")
+
+        output_size = get_resize_output_image_size(
+            image,
+            size=size,
+            default_to_square=default_to_square,
+            input_data_format=input_data_format,
+        )
+        return resize(
+            image,
+            size=output_size,
+            resample=resample,
+            data_format=data_format,
+            input_data_format=input_data_format,
+            **kwargs,
+        )
+
+    def chunk(
+        self, video: VideoInput, fps: int, chunk_duration: float, num_chunks: int, num_frames_per_chunk: int
+    ) -> List[VideoInput]:
+        """
+        Uniformly sample `num_chunks` chunks of duration `chunk_duration` from a video.
+
+        Args:
+            video (`VideoInput`):
+                Video to chunk.
+            fps (`int`):
+                Frame rate of the video.
+            chunk_duration (`float`):
+                Duration of each chunk.
+            num_chunks (`int`):
+                Number of chunks to sample.
+            num_frames_per_chunk (`int`):
+                Number of frames to sample per chunk.
+        """
+        video_duration = len(video) / fps
+        if video_duration < chunk_duration:
+            logger.warning_once(
+                "Chunk duration is greater than video duration. Chunks will be repeated, consider adjusting either `chunk_duration` or `num_chunks` "
+                "to avoid unnecessary memory/compute usage."
+ ) + + all_clips_timepoints = uniform_chunk_sampling(video_duration, chunk_duration, num_chunks) + + all_clips = [] + for clip_timepoints in all_clips_timepoints: + video_clip = video[math.ceil(clip_timepoints[0] * fps) : math.ceil(clip_timepoints[1] * fps)] + video_clip = uniform_temporal_subsample(video_clip, num_samples=num_frames_per_chunk) + all_clips.append(video_clip) + + return all_clips + + def center_crop( + self, + image: np.ndarray, + size: Dict[str, int], + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, + ) -> np.ndarray: + """ + Center crop an image to `(size["height"], size["width"])`. If the input size is smaller than `crop_size` along + any edge, the image is padded with 0's and then center cropped. + + Args: + image (`np.ndarray`): + Image to center crop. + size (`Dict[str, int]`): + Size of the output image. + data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format for the output image. If unset, the channel dimension format of the input + image is used. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + """ + size = get_size_dict(size) + if "height" not in size or "width" not in size: + raise ValueError(f"The size dictionary must have keys 'height' and 'width'. 
Got {size.keys()}") + return modified_center_crop( + image, + size=(size["height"], size["width"]), + data_format=data_format, + input_data_format=input_data_format, + **kwargs, + ) + + def _preprocess_image( + self, + images: ImageInput, + is_video: bool = False, + do_resize: bool = None, + size: Dict[str, int] = None, + resample: PILImageResampling = None, + do_center_crop: bool = None, + crop_size: int = None, + do_rescale: bool = None, + rescale_factor: float = None, + do_normalize: bool = None, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + do_convert_rgb: bool = None, + data_format: Optional[ChannelDimension] = ChannelDimension.FIRST, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ) -> np.ndarray: + validate_preprocess_arguments( + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_center_crop=do_center_crop, + crop_size=crop_size, + do_resize=do_resize, + size=size, + resample=resample, + ) + + if do_convert_rgb: + images = [convert_to_rgb(image) for image in images] + + # All transformations expect numpy arrays. + images = [to_numpy_array(image) for image in images] + + if is_scaled_image(images[0]) and do_rescale: + logger.warning_once( + "It looks like you are trying to rescale already rescaled images. If the input" + " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again." + ) + + if input_data_format is None: + # We assume that all images have the same channel dimension format. 
+ input_data_format = infer_channel_dimension_format(images[0]) + + if do_resize and is_video: + images = self.video_resize( + frames=images, size=size, resampling=resample, input_data_format=input_data_format + ) + + all_images = [] + for image in images: + if do_resize and not is_video: + image = self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format) + + if do_center_crop: + image = self.center_crop(image=image, size=crop_size, input_data_format=input_data_format) + + if do_rescale: + image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format) + + if do_normalize: + image = self.normalize( + image=image, mean=image_mean, std=image_std, input_data_format=input_data_format + ) + + all_images.append(image) + images = [ + to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) + for image in all_images + ] + + return images + + def preprocess( + self, + images: Optional[ImageInput] = None, + videos: Optional[VideoInput] = None, + do_resize: bool = None, + size: Dict[str, int] = None, + resample: PILImageResampling = None, + do_center_crop: bool = None, + crop_size: int = None, + do_rescale: bool = None, + rescale_factor: float = None, + do_normalize: bool = None, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + do_convert_rgb: bool = None, + do_chunk: bool = None, + chunk_duration: float = None, + num_chunks: int = None, + num_frames_per_chunk: int = None, + fps: int = None, + return_tensors: Optional[Union[str, TensorType]] = None, + data_format: Optional[ChannelDimension] = ChannelDimension.FIRST, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, + ) -> PIL.Image.Image: + """ + Preprocess an image or batch of images. + + Args: + images (`ImageInput`, *optional*): + Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. 
If + passing in images with pixel values between 0 and 1, set `do_rescale=False`. Either `images` or + `videos` must be provided. + videos (`VideoInput`, *optional*): + Video to preprocess. Expects a single or batch of videos with pixel values ranging from 0 to 255. If + passing in videos with pixel values between 0 and 1, set `do_rescale=False`. Either `images` or + `videos` must be provided. + do_resize (`bool`, *optional*, defaults to `self.do_resize`): + Whether to resize the image. + size (`Dict[str, int]`, *optional*, defaults to `self.size`): + Size of the image after resizing. Shortest edge of the image is resized to size["shortest_edge"], with + the longest edge resized to keep the input aspect ratio. + resample (`int`, *optional*, defaults to `self.resample`): + Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only + has an effect if `do_resize` is set to `True`. + do_center_crop (`bool`, *optional*, defaults to `self.do_center_crop`): + Whether to center crop the image. + crop_size (`Dict[str, int]`, *optional*, defaults to `self.crop_size`): + Size of the center crop. Only has an effect if `do_center_crop` is set to `True`. + do_rescale (`bool`, *optional*, defaults to `self.do_rescale`): + Whether to rescale the image. + rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`): + Rescale factor to rescale the image by if `do_rescale` is set to `True`. + do_normalize (`bool`, *optional*, defaults to `self.do_normalize`): + Whether to normalize the image. + image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`): + Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`. + image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`): + Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to + `True`. 
+ do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`): + Whether to convert the image to RGB. + do_chunk (`bool`, *optional*, defaults to `self.do_chunk`): + Whether to chunk the video into multiple clips. + chunk_duration (`float`, *optional*, defaults to `self.chunk_duration`): + Duration of each chunk in seconds. + num_chunks (`int`, *optional*, defaults to `self.num_chunks`): + Number of chunks to sample. + num_frames_per_chunk (`int`, *optional*, defaults to `self.num_frames_per_chunk`): + Number of frames to sample per chunk. + fps (`int`, *optional*, defaults to `self.fps`): + Frame rate of the video. It's assumed that all videos have the same frame rate. + return_tensors (`str` or `TensorType`, *optional*): + The type of tensors to return. Can be one of: + - Unset: Return a list of `np.ndarray`. + - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. + - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. + - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. + - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. + data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): + The channel dimension format for the output image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - Unset: Use the channel dimension format of the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. 
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + """ + if images is None and videos is None: + raise ValueError("Either `images` or `videos` must be provided.") + + if images is not None and videos is not None: + raise ValueError("Only one of `images` or `videos` can be provided.") + + do_resize = do_resize if do_resize is not None else self.do_resize + size = size if size is not None else self.size + size = get_size_dict(size, param_name="size", default_to_square=False) + resample = resample if resample is not None else self.resample + do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop + crop_size = crop_size if crop_size is not None else self.crop_size + crop_size = get_size_dict(crop_size, param_name="crop_size", default_to_square=True) + do_rescale = do_rescale if do_rescale is not None else self.do_rescale + rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor + do_normalize = do_normalize if do_normalize is not None else self.do_normalize + image_mean = image_mean if image_mean is not None else self.image_mean + image_std = image_std if image_std is not None else self.image_std + do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb + do_chunk = do_chunk if do_chunk is not None else self.do_chunk + chunk_duration = chunk_duration if chunk_duration is not None else self.chunk_duration + num_chunks = num_chunks if num_chunks is not None else self.num_chunks + num_frames_per_chunk = num_frames_per_chunk if num_frames_per_chunk is not None else self.num_frames_per_chunk + fps = fps if fps is not None else self.fps + + if images is not None: + images = make_list_of_images(images) + if videos is not None: + videos = make_batched_videos(videos) + + validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self._valid_processor_keys) + + if (videos is not None and not valid_images(videos)) or (images is not None and not 
valid_images(images)): + raise ValueError( + "Invalid input type. Must be of type PIL.Image.Image, numpy.ndarray, " + "torch.Tensor, tf.Tensor or jax.ndarray." + ) + + if images is not None: + pixel_values = self._preprocess_image( + images=images, + do_resize=do_resize, + size=size, + resample=resample, + do_center_crop=do_center_crop, + crop_size=crop_size, + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_convert_rgb=do_convert_rgb, + data_format=data_format, + input_data_format=input_data_format, + ) + else: + pixel_values = [] + for video in videos: + if do_chunk: + clips = self.chunk( + video=video, + fps=fps, + chunk_duration=chunk_duration, + num_chunks=num_chunks, + num_frames_per_chunk=num_frames_per_chunk, + ) + + _pixel_values = [ + self._preprocess_image( + images=clip, + is_video=True, + do_resize=do_resize, + size=size, + resample=PILImageResampling.BILINEAR, + do_center_crop=do_center_crop, + crop_size=crop_size, + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_convert_rgb=do_convert_rgb, + data_format=data_format, + input_data_format=input_data_format, + ) + for clip in clips + ] + else: + _pixel_values = [ + self._preprocess_image( + images=video, + do_resize=do_resize, + size=size, + resample=resample, + do_center_crop=do_center_crop, + crop_size=crop_size, + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_convert_rgb=do_convert_rgb, + data_format=data_format, + input_data_format=input_data_format, + ) + ] + + # Avoid List[List[List[np.ndarray]]] for performance reasons + _pixel_values = np.stack(_pixel_values) + # Make it shape (num_chunks, num_channels, num_frames_per_chunk, height, width) + _pixel_values = np.swapaxes(_pixel_values, 1, 2) + pixel_values.append(_pixel_values) + 
+ return BatchFeature(data={"pixel_values": pixel_values}, tensor_type=return_tensors) + + +__all__ = ["ImageBindImageProcessor"] diff --git a/src/transformers/models/imagebind/modeling_imagebind.py b/src/transformers/models/imagebind/modeling_imagebind.py new file mode 100644 index 000000000000..d8860755b66b --- /dev/null +++ b/src/transformers/models/imagebind/modeling_imagebind.py @@ -0,0 +1,2058 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""PyTorch ImageBind model.""" + +import collections.abc +from dataclasses import dataclass +from typing import Any, Optional, Tuple, Union + +import numpy as np +import torch +import torch.utils.checkpoint +from torch import nn + +from ...activations import ACT2FN +from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling +from ...modeling_utils import PreTrainedModel +from ...utils import ( + ModelOutput, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, + torch_int, +) +from .configuration_imagebind import ( + ImageBindAudioConfig, + ImageBindConfig, + ImageBindTextConfig, + ImageBindVisionConfig, +) + + +logger = logging.get_logger(__name__) + + +# Copied from transformers.models.clip.modeling_clip.contrastive_loss +def contrastive_loss(logits: torch.Tensor) -> torch.Tensor: + return nn.functional.cross_entropy(logits, torch.arange(len(logits), device=logits.device)) + + +# Copied from transformers.models.clip.modeling_clip.clip_loss with clip->imagebind +def imagebind_loss(similarity: torch.Tensor) -> torch.Tensor: + caption_loss = contrastive_loss(similarity) + image_loss = contrastive_loss(similarity.t()) + return (caption_loss + image_loss) / 2.0 + + +@dataclass +class ImageBindTransformerOutput(ModelOutput): + """ + The output class for ImageBind*Transformer models. This is [`BaseModelOutputWithPooling`] with an additional + `num_clips` field for modalities which are organized into clips as well as batches (vision, audio). + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`): + Last layer hidden-state of the first token of the sequence (classification token) after further processing + through the layers used for the auxiliary pretraining task. E.g. 
for BERT-family of models, this returns + the classification token after processing through a linear layer and a tanh activation function. The linear + layer weights are trained from the next sentence prediction (classification) objective during pretraining. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + last_hidden_state: torch.FloatTensor = None + pooler_output: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +@dataclass +class ImageBindTextModelOutput(ModelOutput): + """ + Base class for text model's outputs. This is [`CLIPTextModelOutput`] that also contains a pooling of the last hidden states + or normalized embeddings. + + Args: + text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`): + The text embeddings obtained by applying the projection layer to the pooler_output. + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. 
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + normalized_text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when model is initialized with `with_projection=True`): + The normalized text embeddings obtained by applying the projection layer to the pooler_output, then + applying L2 normalization and scaling the logits. + """ + + text_embeds: Optional[torch.FloatTensor] = None + last_hidden_state: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + normalized_text_embeds: Optional[torch.FloatTensor] = None + + +@dataclass +class ImageBindVisionModelOutput(ModelOutput): + """ + Base class for vision model's outputs. This is [`CLIPVisionModelOutput`] that also contains image embeddings of the pooling of the + last hidden states. + + Args: + image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`): + The image embeddings obtained by applying the projection layer to the pooler_output. 
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + normalized_image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when model is initialized with `with_projection=True`): + The normalized image embeddings obtained by applying the projection layer to the pooler_output, then + applying L2 normalization and scaling the logits. + """ + + image_embeds: Optional[torch.FloatTensor] = None + last_hidden_state: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + normalized_image_embeds: Optional[torch.FloatTensor] = None + + +@dataclass +class ImageBindAudioModelOutput(ModelOutput): + """ + Base class for audio model's outputs. This is [`ClapAudioModelOutput`] that also contains a pooling of the last hidden states + or normalized embeddings. 
+ + Args: + audio_embeds (`torch.FloatTensor` of shape `(batch_size, hidden_size)`): + The Audio embeddings obtained by applying the projection layer to the pooler_output. + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + normalized_audio_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when model is initialized with `with_projection=True`): + The normalized audio embeddings obtained by applying the projection layer to the pooler_output, then + applying L2 normalization and scaling the logits. + """ + + audio_embeds: Optional[torch.FloatTensor] = None + last_hidden_state: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + normalized_audio_embeds: Optional[torch.FloatTensor] = None + + +@dataclass +class ImageBindOutput(ModelOutput): + """ + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`): + Contrastive loss for image-text similarity. 
+ logits_per_image:(`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`): + The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text + similarity scores. + logits_per_text:(`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`): + The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image + similarity scores. + logits_per_audio:(`torch.FloatTensor` of shape `(audio_batch_size, image_batch_size)`): + The scaled dot product scores between `audio_embeds` and `image_embeds`. This represents the audio-image + similarity scores. + text_embeds(`torch.FloatTensor` of shape `(batch_size, output_dim)`): + The normalized text embeddings obtained by applying the projection layer to the pooled output of [`ImageBindTextModel`], then applying L2 normalization and logit scaling. + image_embeds(`torch.FloatTensor` of shape `(batch_size, output_dim)`): + The normalized image embeddings obtained by applying the projection layer to the pooled output of [`ImageBindVisionModel`], then applying L2 normalization and logit scaling. + audio_embeds(`torch.FloatTensor` of shape `(batch_size, output_dim)`): + The normalized audio embeddings obtained by applying the projection layer to the pooled output of [`ImageBindAudioModel`], then applying L2 normalization and logit scaling. + text_model_output(`BaseModelOutputWithPooling`): + The output of the [`ImageBindTextModel`]. + vision_model_output(`BaseModelOutputWithPooling`): + The output of the [`ImageBindVisionModel`]. + audio_model_output(`BaseModelOutputWithPooling`): + The output of the [`ImageBindAudioModel`]. 
+ """ + + loss: Optional[torch.FloatTensor] = None + logits_per_image: torch.FloatTensor = None + logits_per_text: torch.FloatTensor = None + logits_per_audio: torch.FloatTensor = None + image_embeds: torch.FloatTensor = None + text_embeds: torch.FloatTensor = None + audio_embeds: torch.FloatTensor = None + vision_model_output: BaseModelOutputWithPooling = None + text_model_output: BaseModelOutputWithPooling = None + audio_model_output: BaseModelOutputWithPooling = None + + def to_tuple(self) -> Tuple[Any]: + fields_to_exclude = [ + "text_model_output", + "vision_model_output", + "audio_model_output", + ] + return tuple(self[k] if k not in fields_to_exclude else getattr(self, k).to_tuple() for k in self.keys()) + + +class ImageBindGenericPatchEmbedding(nn.Module): + """Generic Patch Embedding class that can be used for Vision (image/video), Audio, Depth, Thermal modalities.""" + + def __init__( + self, + config: Union[ImageBindVisionConfig, ImageBindAudioConfig], + projection: nn.Module, + use_layernorm: bool = False, + ): + super().__init__() + + image_size = config.feature_size + + self.image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size) + self.num_channels = config.num_channels + + self.projection = projection + self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) if use_layernorm else None + + def forward(self, input_values: torch.FloatTensor, interpolate_pos_encoding: bool = False) -> torch.Tensor: + if input_values.ndim not in [4, 5]: + raise ValueError(f"Input tensor shape should have length 4 or 5 but got {input_values.ndim}.") + + _, num_channels, *spatial_shape = input_values.shape + height, width = spatial_shape[-2:] + + if num_channels != self.num_channels: + raise ValueError( + "Make sure that the channel dimension of the pixel values match with the one set in the configuration." + f" Expected {self.num_channels} but got {num_channels}." 
+ ) + if not interpolate_pos_encoding: + if height != self.image_size[0] or width != self.image_size[1]: + raise ValueError( + f"Input image size ({height}*{width}) doesn't match model" + f" ({self.image_size[0]}*{self.image_size[1]})." + ) + + embeddings = self.projection(input_values).flatten(2).transpose(1, 2) + if self.layernorm is not None: + embeddings = self.layernorm(embeddings) + + return embeddings + + +class ImageBindVisionEmbeddings(nn.Module): + def __init__(self, config: ImageBindVisionConfig): + super().__init__() + self.config = config + self.num_frames = config.num_frames + num_patches = (config.image_size // config.patch_size) ** 2 + + projection = nn.Conv3d( + in_channels=config.num_channels, + out_channels=config.hidden_size, + kernel_size=(config.num_frames, config.patch_size, config.patch_size), + stride=(config.num_frames, config.patch_size, config.patch_size), + bias=False, + ) + self.patch_embedding = ImageBindGenericPatchEmbedding( + config=config, projection=projection, use_layernorm=False + ) + self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size)) + self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size)) + + # Copied from transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding + def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: + """ + This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution + images. This method is also adapted to support torch.jit tracing. 
+ + Adapted from: + - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and + - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211 + """ + + num_patches = embeddings.shape[1] - 1 + num_positions = self.position_embeddings.shape[1] - 1 + + # always interpolate when tracing to ensure the exported model works for dynamic input shapes + if not torch.jit.is_tracing() and num_patches == num_positions and height == width: + return self.position_embeddings + + class_pos_embed = self.position_embeddings[:, :1] + patch_pos_embed = self.position_embeddings[:, 1:] + + dim = embeddings.shape[-1] + + new_height = height // self.patch_size + new_width = width // self.patch_size + + sqrt_num_positions = torch_int(num_positions**0.5) + patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim) + patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed, + size=(new_height, new_width), + mode="bicubic", + align_corners=False, + ) + + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + + return torch.cat((class_pos_embed, patch_pos_embed), dim=1) + + def image_to_video(self, pixel_values: torch.FloatTensor, time_dim: int = 2, num_frames: int = 2): + """ + Maps 4-dim image tensors of shape (B, C, H, W) to 5-dim video tensors, possibly repeating the image along the + time dimension. For example, if `time_dim == 1`, RGB images of shape (B, C, H, W) will be transformed to + video of shape (B, 1, C, H, W), and then the image will be repeated along the time dimension `num_frames` to get + shape (B, N, C, H, W). + """ + if pixel_values.ndim not in [4, 5]: + raise ValueError( + f"The input `image` tensor should be 4- or 5-dimensional but has {pixel_values.ndim} dimensions." 
+ ) + + # Add time dimension at specified dim index + if pixel_values.ndim == 4: + pixel_values = pixel_values.unsqueeze(time_dim) + + # Repeat image across the time dimension num_frames. + if pixel_values.shape[time_dim] == 1: + new_shape = [1] * len(pixel_values.shape) + new_shape[time_dim] = num_frames + pixel_values = pixel_values.repeat(new_shape) + + return pixel_values + + def forward( + self, + pixel_values: torch.FloatTensor, + interpolate_pos_encoding: bool = False, + ) -> torch.Tensor: + pixel_values = self.image_to_video(pixel_values, num_frames=self.num_frames) + batch_size, num_channels, num_frames, height, width = pixel_values.shape + + embeddings = self.patch_embedding(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding) + + cls_tokens = self.cls_token.expand(batch_size, -1, -1) + embeddings = torch.cat((cls_tokens, embeddings), dim=1) + + # add positional encoding to each token + if interpolate_pos_encoding: + embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width) + else: + embeddings = embeddings + self.position_embeddings + + return embeddings + + +class ImageBindAudioEmbeddings(nn.Module): + def __init__(self, config: ImageBindAudioConfig): + super().__init__() + self.config = config + + num_patches_height = int((config.num_mel_bins - config.patch_size) / config.stride + 1) + num_patches_width = int((config.target_len - config.patch_size) / config.stride + 1) + num_patches = num_patches_height * num_patches_width + + proj = nn.Conv2d( + in_channels=config.num_channels, + out_channels=config.hidden_size, + kernel_size=config.patch_size, + stride=config.stride, + bias=False, + ) + + self.patch_embedding = ImageBindGenericPatchEmbedding(config=config, projection=proj, use_layernorm=True) + + self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size)) + self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size)) + + def forward(self, input_features: torch.FloatTensor) 
-> torch.Tensor: + embeddings = self.patch_embedding(input_features, interpolate_pos_encoding=False) + + cls_tokens = self.cls_token.expand(embeddings.shape[0], -1, -1) + embeddings = torch.cat((cls_tokens, embeddings), dim=1) + + # Could also add interpolation of position encoding as well + embeddings = embeddings + self.position_embeddings + + return embeddings + + +# Copied from transformers.models.clip.modeling_clip.CLIPTextEmbeddings with CLIP->ImageBind +class ImageBindTextEmbeddings(nn.Module): + def __init__(self, config: ImageBindTextConfig): + super().__init__() + embed_dim = config.hidden_size + + self.token_embedding = nn.Embedding(config.vocab_size, embed_dim) + self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim) + + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.register_buffer( + "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False + ) + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + ) -> torch.Tensor: + seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2] + + if position_ids is None: + position_ids = self.position_ids[:, :seq_length] + + if inputs_embeds is None: + inputs_embeds = self.token_embedding(input_ids) + + position_embeddings = self.position_embedding(position_ids) + embeddings = inputs_embeds + position_embeddings + + return embeddings + + +class ImageBindAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper. 
This is [`CLIPAttention`] with key and value biases""" + + def __init__(self, config): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads})." + ) + self.scale = self.head_dim**-0.5 + self.dropout = config.attention_dropout + + self.qkv_proj = nn.Linear(self.embed_dim, self.embed_dim * 3) + self.out_proj = nn.Linear(self.embed_dim, self.embed_dim) + + self.add_kv_bias = config.add_kv_bias + + # Create bias parameters for key and value sequences. + if self.add_kv_bias: + self.k_bias = nn.Parameter(torch.empty((1, 1, self.embed_dim))) + self.v_bias = nn.Parameter(torch.empty((1, 1, self.embed_dim))) + else: + self.k_bias = None + self.v_bias = None + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + batch_size, seq_len, embed_dim = hidden_states.size() + + qkv = self.qkv_proj(hidden_states).reshape(batch_size, seq_len, 3, -1).permute(2, 0, 1, 3) + query_states, key_states, value_states = qkv.unbind(0) + + query_states = query_states * self.scale + + # Add key/value biases if necessary + if self.add_kv_bias: + # Repeat bias along batch dimension (first) + key_states = torch.cat([key_states, self.k_bias.repeat(batch_size, 1, 1)], dim=1) + value_states = torch.cat([value_states, self.v_bias.repeat(batch_size, 1, 1)], dim=1) + + key_states = key_states.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + value_states = value_states.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 
2).contiguous() + + proj_shape = (batch_size * self.num_heads, -1, self.head_dim) + query_states = query_states.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2).contiguous().view(*proj_shape) + key_states = key_states.view(*proj_shape) + value_states = value_states.view(*proj_shape) + + src_len = key_states.size(1) + attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) + + if attn_weights.size() != (batch_size * self.num_heads, seq_len, src_len): + raise ValueError( + f"Attention weights should be of size {(batch_size * self.num_heads, seq_len, src_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (batch_size, 1, seq_len, src_len): + raise ValueError( + f"Attention mask should be of size {(batch_size, 1, seq_len, src_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights.view(batch_size, self.num_heads, seq_len, src_len) + attention_mask + attn_weights = attn_weights.view(batch_size * self.num_heads, seq_len, src_len) + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if output_attentions: + # this operation is a bit akward, but it's required to + # make sure that attn_weights keeps its gradient. 
+ # In order to do so, attn_weights have to reshaped + # twice and have to be reused in the following + attn_weights_reshaped = attn_weights.view(batch_size, self.num_heads, seq_len, src_len) + attn_weights = attn_weights_reshaped.view(batch_size * self.num_heads, seq_len, src_len) + else: + attn_weights_reshaped = None + + attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + + attn_output = torch.bmm(attn_probs, value_states) + + if attn_output.size() != (batch_size * self.num_heads, seq_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(batch_size, self.num_heads, seq_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.view(batch_size, self.num_heads, seq_len, self.head_dim) + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.reshape(batch_size, seq_len, embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights_reshaped + + +class ImageBindMlp(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.activation_fn = ACT2FN[config.hidden_act] + intermediate_size = config.intermediate_size + + self.fc1 = nn.Linear(config.hidden_size, intermediate_size) + self.fc2 = nn.Linear(intermediate_size, config.hidden_size) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +# Copied from transformers.models.beit.modeling_beit.drop_path +def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor: + """ + Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 
+ + Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks, + however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the + layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the + argument. + """ + if drop_prob == 0.0 or not training: + return input + keep_prob = 1 - drop_prob + shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device) + random_tensor.floor_() # binarize + output = input.div(keep_prob) * random_tensor + return output + + +# Copied from transformers.models.beit.modeling_beit.BeitDropPath with Beit->ImageBind +class ImageBindDropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob: Optional[float] = None) -> None: + super().__init__() + self.drop_prob = drop_prob + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + return drop_path(hidden_states, self.drop_prob, self.training) + + def extra_repr(self) -> str: + return "p={}".format(self.drop_prob) + + +class ImageBindEncoderLayer(nn.Module): + """This is [`CLIPEncoderLayer`] with DropPath layer after each residual subblock (attention, feedforward)""" + + def __init__( + self, + config: Union[ImageBindVisionConfig, ImageBindTextConfig, ImageBindAudioConfig], + drop_path_rate: float = 0.0, + ): + super().__init__() + self.embed_dim = config.hidden_size + self.self_attn = ImageBindAttention(config) + self.layernorm_before = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + self.mlp = ImageBindMlp(config) + self.layernorm_after = nn.LayerNorm(self.embed_dim, 
eps=config.layer_norm_eps) + if drop_path_rate > 0.0: + self.drop_path = ImageBindDropPath(drop_path_rate) + else: + self.drop_path = nn.Identity() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.FloatTensor]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + `(config.encoder_attention_heads,)`. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + """ + residual = hidden_states + + hidden_states = self.layernorm_before(hidden_states) + hidden_states, attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + ) + hidden_states = self.drop_path(hidden_states) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.layernorm_after(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = self.drop_path(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +class ImageBindPostProcessor(nn.Module): + """ + Post-processes ImageBind embeddings by using a normalize layer followed by an optional logit scaling layer. + + Args: + config (Union[ImageBindTextConfig, ImageBindVisionConfig,ImageBindAudioConfig]): A configuration object that contains + initialization values for logit scaling. + dim (int, optional): The dimension along which to normalize the logits. Default is -1, which indicates the last dimension. 
+ max_logit_scale (float, optional): The maximum value to which the logit scale can be clipped. Default is 100. + """ + + def __init__( + self, + config, + dim: int = -1, + max_logit_scale: float = 100, + ): + super().__init__() + self.dim = dim + self.scale_logits = config.logit_scale_init_value is not None + + if self.scale_logits: + self.logit_scale_init = config.logit_scale_init_value + self.max_logit_scale = max_logit_scale + self.learnable = config.learnable_logit_scale + + log_logit_scale = torch.ones([]) * np.log(self.logit_scale_init) + if self.learnable: + self.log_logit_scale = nn.Parameter(log_logit_scale) + else: + self.register_buffer("log_logit_scale", log_logit_scale) + + def forward(self, logits: torch.FloatTensor) -> torch.FloatTensor: + logits = nn.functional.normalize(logits, dim=self.dim, p=2) + if self.scale_logits: + logits = torch.clip(self.log_logit_scale.exp(), max=self.max_logit_scale) * logits + return logits + + +class ImageBindPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. 
+ """ + + config_class = ImageBindConfig + base_model_prefix = "imagebind" + supports_gradient_checkpointing = True + + def _init_weights(self, module): + """Initialize the weights""" + factor = self.config.initializer_factor + + def init_projection(proj, embed_dim): + nn.init.normal_(proj.weight, std=embed_dim**-0.5 * factor) + + if isinstance(module, ImageBindTextEmbeddings): + module.token_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02) + module.position_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02) + + elif isinstance(module, (ImageBindVisionEmbeddings, ImageBindAudioEmbeddings)): + nn.init.normal_(module.cls_token, std=module.config.hidden_size**-0.5 * factor) + nn.init.normal_(module.patch_embedding.projection.weight, std=module.config.initializer_range * factor) + nn.init.normal_(module.position_embeddings, std=module.config.initializer_range * factor) + + elif isinstance(module, ImageBindAttention): + layer_factor = (2 * module.config.num_hidden_layers) ** -0.5 + in_proj_std = (module.embed_dim**-0.5) * layer_factor * factor + out_proj_std = (module.embed_dim**-0.5) * factor + nn.init.normal_(module.qkv_proj.weight, std=in_proj_std) + nn.init.normal_(module.out_proj.weight, std=out_proj_std) + if module.k_bias is not None: + nn.init.normal_(module.k_bias, std=in_proj_std) + if module.v_bias is not None: + nn.init.normal_(module.v_bias, std=in_proj_std) + + elif isinstance(module, ImageBindMlp): + layer_factor = (2 * module.config.num_hidden_layers) ** -0.5 + in_proj_std = (module.config.hidden_size**-0.5) * layer_factor * factor + fc_std = (2 * module.config.hidden_size) ** -0.5 * factor + nn.init.normal_(module.fc1.weight, std=fc_std) + nn.init.normal_(module.fc2.weight, std=in_proj_std) + elif isinstance(module, ImageBindModel): + init_projection(module.text_projection, module.text_embed_dim) + init_projection(module.vision_projection, module.vision_embed_dim) + init_projection(module.audio_projection, module.audio_embed_dim) + for 
config, modality in zip( + [self.config.text_config, self.config.vision_config, self.config.audio_config], + ["text", "vision", "audio"], + ): + if config.logit_scale_init_value is not None and config.learnable_logit_scale: + logit_scale = torch.ones([]) * np.log(config.logit_scale_init_value) * factor + getattr(module, f"{modality}_postprocessor").log_logit_scale = nn.Parameter(logit_scale) + elif isinstance( + module, + (ImageBindVisionModelWithProjection, ImageBindTextModelWithProjection, ImageBindAudioModelWithProjection), + ): + modality = module.__class__.__name__.replace("ModelWithProjection", "").replace("ImageBind", "").lower() + init_projection(getattr(module, f"{modality}_projection"), self.config.hidden_size) + if self.config.logit_scale_init_value is not None and self.config.learnable_logit_scale: + logit_scale = torch.ones([]) * np.log(self.config.logit_scale_init_value) * factor + getattr(module, f"{modality}_postprocessor").log_logit_scale = nn.Parameter(logit_scale) + if isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, ImageBindEncoder): + module.gradient_checkpointing = value + + +IMAGEBIND_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`ImageBindConfig`]): Model configuration class with all the parameters of the model. 
+ Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +IMAGEBIND_TEXT_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + +IMAGEBIND_VISION_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using + [`AutoImageProcessor`]. 
See [`ImageBindImageProcessor.__call__`] for details. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + +IMAGEBIND_AUDIO_INPUTS_DOCSTRING = r""" + Args: + input_features (`torch.FloatTensor` of shape `(batch_size, num_mel_bins, target_len)`): + Input features. Padding will be ignored by default should you provide it. Input features can be obtained + using [`AutoFeatureExtractor`]. See [`ImageBindFeatureExtractor.__call__`] for details. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + +IMAGEBIND_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using + [`AutoImageProcessor`]. See [`ImageBindImageProcessor.__call__`] for details. + input_features (`torch.FloatTensor` of shape `(batch_size, num_mel_bins, target_len)`): + Input features. Padding will be ignored by default should you provide it. Input features can be obtained + using [`AutoFeatureExtractor`]. See [`ImageBindFeatureExtractor.__call__`] for details. 
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        return_loss (`bool`, *optional*):
            Whether or not to return the contrastive loss.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""


class ImageBindEncoder(nn.Module):
    """
    Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
    [`ImageBindEncoderLayer`]. This is [`CLIPEncoder`] with DropPath (stochastic depth) support.

    Args:
        config: ImageBindConfig
    """

    def __init__(self, config: ImageBindConfig):
        super().__init__()
        self.config = config

        # Stochastic depth: drop-path probability grows linearly from 0 to config.drop_path_rate
        # over the layer stack.
        drop_path_rates = [prob.item() for prob in torch.linspace(0, config.drop_path_rate, config.num_hidden_layers)]
        self.layers = nn.ModuleList(
            [ImageBindEncoderLayer(config, drop_path_rate) for drop_path_rate in drop_path_rates]
        )
        self.gradient_checkpointing = False

    def forward(
        self,
        inputs_embeds,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutput]:
        r"""
        Args:
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
                This is useful if you want more control over how to convert `input_ids` indices into associated vectors
                than the model's internal embedding lookup matrix.
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)

            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        encoder_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        hidden_states = inputs_embeds
        for idx, encoder_layer in enumerate(self.layers):
            # `encoder_states` collects the input of each layer, plus the final output below.
            if output_hidden_states:
                encoder_states = encoder_states + (hidden_states,)
            if self.gradient_checkpointing and self.training:
                # Re-compute activations in the backward pass to trade compute for memory.
                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs, output_attentions)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(encoder_layer),
                    hidden_states,
                    attention_mask,
                )
            else:
                layer_outputs = encoder_layer(
                    hidden_states,
                    attention_mask,
                    output_attentions=output_attentions,
                )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        if output_hidden_states:
            encoder_states = encoder_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
        )


class ImageBindTextTransformer(nn.Module):
    """Text tower: token/position embeddings -> causal-attention encoder -> final LayerNorm."""

    def __init__(self, config: ImageBindTextConfig):
        super().__init__()
        self.config = config
        embed_dim = config.hidden_size
        self.embeddings = ImageBindTextEmbeddings(config)
        self.encoder = ImageBindEncoder(config)
        self.layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

    @add_start_docstrings_to_model_forward(IMAGEBIND_TEXT_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=ImageBindTransformerOutput, config_class=ImageBindTextConfig)
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, ImageBindTransformerOutput]:
        r"""
        Returns:
            Union[Tuple, ImageBindTransformerOutput]
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is None:
            raise ValueError("You have to specify input_ids")

        input_shape = input_ids.size()
        input_ids = input_ids.view(-1, input_shape[-1])

        hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids)

        batch_size, seq_len = input_shape

        # Causal mask, optionally combined with the padding mask (see _build_attention_mask).
        attention_mask = self._build_attention_mask(
            attention_mask, batch_size, seq_len, hidden_states.dtype, hidden_states.device
        )

        encoder_outputs = self.encoder(
            inputs_embeds=hidden_states,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        last_hidden_state = encoder_outputs[0]
        last_hidden_state = self.layernorm(last_hidden_state)

        # Pool at the position of the highest token id (the EOS token, as in CLIP).
        # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
        pooled_output = last_hidden_state[
            torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
            input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(dim=-1),
        ]

        if not return_dict:
            return (last_hidden_state, pooled_output) + encoder_outputs[1:]

        return ImageBindTransformerOutput(
            last_hidden_state=last_hidden_state,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )

    @staticmethod
    def _expand_mask(mask: torch.Tensor, dtype: torch.dtype):
        # Expand `(batch, seq)` -> `(batch, 1, 1, seq)` and invert: kept positions (mask == 1)
        # become 0.0, masked positions (mask == 0) become the dtype's most negative value.
        return (1.0 - mask[:, None, None, :].to(dtype)).masked_fill(mask[:, None, None, :].to(dtype) == 0, torch.finfo(dtype).min)

    def _build_attention_mask(self, attention_mask, batch_size, seq_len, dtype, device=None):
        """Build an additive causal mask, optionally merged with a padding `attention_mask`."""
        # Build causal mask: upper triangle (strictly above the diagonal) is -inf-like.
        mask = torch.empty(batch_size, seq_len, seq_len, dtype=dtype, device=device)
        mask.fill_(torch.finfo(dtype).min)
        mask.triu_(1)
        mask = mask.unsqueeze(1)  # expand mask to (batch, 1, seq, seq)

        # If attention_mask is given, add the (broadcast) padding mask to the causal mask.
        if attention_mask is not None:
            attention_mask = self._expand_mask(attention_mask, dtype)
            return mask + attention_mask
        return mask


@add_start_docstrings(
    """The text model from ImageBind without any head or projection on top.""",
    IMAGEBIND_START_DOCSTRING,
)
class ImageBindTextModel(ImageBindPreTrainedModel):
    config_class = ImageBindTextConfig

    _no_split_modules = ["ImageBindEncoderLayer"]

    def __init__(self, config: ImageBindTextConfig):
        super().__init__(config)
        self.text_model = ImageBindTextTransformer(config)
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        return self.text_model.embeddings.token_embedding

    def set_input_embeddings(self, value):
        self.text_model.embeddings.token_embedding = value

    @add_start_docstrings_to_model_forward(IMAGEBIND_TEXT_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=ImageBindTransformerOutput, config_class=ImageBindTextConfig)
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, ImageBindTransformerOutput]:
        r"""
        Returns:

        Examples:

        ```python
        >>> from transformers import AutoTokenizer, ImageBindTextModel

        >>> model = ImageBindTextModel.from_pretrained("EduardoPacheco/imagebind-huge")
        >>> tokenizer = AutoTokenizer.from_pretrained("EduardoPacheco/imagebind-huge")

        >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")

        >>> outputs = model(**inputs)
        >>> last_hidden_state = outputs.last_hidden_state
        >>> pooled_output = outputs.pooler_output  # pooled (EOS token) states
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Thin wrapper: all work is delegated to the inner ImageBindTextTransformer.
        return self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )


class ImageBindVisionTransformer(nn.Module):
    """Vision tower: patch embeddings -> pre-LayerNorm -> encoder -> LayerNorm over the CLS token."""

    def __init__(self, config: ImageBindVisionConfig):
        super().__init__()
        self.config = config
        embed_dim = config.hidden_size

        self.embeddings = ImageBindVisionEmbeddings(config)
        self.pre_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
        self.encoder = ImageBindEncoder(config)
        self.layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

    @add_start_docstrings_to_model_forward(IMAGEBIND_VISION_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=ImageBindTransformerOutput, config_class=ImageBindVisionConfig)
    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, ImageBindTransformerOutput]:
        r"""
        Returns:
            Union[Tuple, ImageBindTransformerOutput]:
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        # For video inputs we take multiple clips and average the embeddings
        # See https://github.com/facebookresearch/ImageBind/blob/main/imagebind/models/imagebind_model.py#L470
        # NOTE(review): >= 5 dims is treated as (batch, clips, ...); clips are flattened into the
        # batch here, and averaging back happens in the callers (e.g. get_image_features).
        reduce_clips = pixel_values.ndim >= 5
        if reduce_clips:
            batch_size, num_clips = pixel_values.shape[:2]
            pixel_values = pixel_values.reshape(batch_size * num_clips, *pixel_values.shape[2:])

        hidden_states = self.embeddings(pixel_values)
        hidden_states = self.pre_layernorm(hidden_states)

        encoder_outputs = self.encoder(
            inputs_embeds=hidden_states,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        last_hidden_state = encoder_outputs[0]
        # Pool the CLS token (position 0), then normalize it.
        pooled_output = last_hidden_state[:, 0, :]
        pooled_output = self.layernorm(pooled_output)

        if not return_dict:
            return (last_hidden_state, pooled_output) + encoder_outputs[1:]

        return ImageBindTransformerOutput(
            last_hidden_state=last_hidden_state,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )


@add_start_docstrings(
    """The vision model from ImageBind without any head or projection on top.""",
    IMAGEBIND_START_DOCSTRING,
)
class ImageBindVisionModel(ImageBindPreTrainedModel):
    config_class = ImageBindVisionConfig
    _no_split_modules = ["ImageBindEncoderLayer"]

    main_input_name = "pixel_values"

    def __init__(self, config: ImageBindVisionConfig):
        super().__init__(config)
        self.vision_model = ImageBindVisionTransformer(config)
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        return self.vision_model.embeddings.patch_embedding

@add_start_docstrings_to_model_forward(IMAGEBIND_VISION_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=ImageBindTransformerOutput, config_class=ImageBindVisionConfig) + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, ImageBindTransformerOutput]: + r""" + Returns: + + Examples: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, ImageBindVisionModel + + >>> model = ImageBindVisionModel.from_pretrained("EduardoPacheco/imagebind-huge") + >>> processor = AutoProcessor.from_pretrained("EduardoPacheco/imagebind-huge") + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = processor(images=image, return_tensors="pt") + + >>> outputs = model(**inputs) + >>> last_hidden_state = outputs.last_hidden_state + >>> pooled_output = outputs.pooler_output # pooled CLS states + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + return self.vision_model( + pixel_values=pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + +class ImageBindAudioTransformer(nn.Module): + def __init__(self, config: ImageBindAudioConfig): + super().__init__() + self.config = config + embed_dim = config.hidden_size + + self.embeddings = ImageBindAudioEmbeddings(config) + self.encoder = ImageBindEncoder(config) + self.layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + + @add_start_docstrings_to_model_forward(IMAGEBIND_AUDIO_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=ImageBindTransformerOutput, config_class=ImageBindAudioConfig) + def forward( + self, + input_features: Optional[torch.FloatTensor] = None, + 
output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, ImageBindTransformerOutput]: + r""" + Returns: + Union[Tuple, ImageBindTransformerOutput] + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_features is None: + raise ValueError("You have to specify input_features") + + # If audio is chunked (i.e. same audio is split into multiple clips), reduce embedding over clips dimension + # See https://github.com/facebookresearch/ImageBind/blob/main/imagebind/models/imagebind_model.py#L470 + reduce_clips = input_features.ndim >= 5 + if reduce_clips: + batch_size, num_clips = input_features.shape[:2] + input_features = input_features.reshape(batch_size * num_clips, *input_features.shape[2:]) + + hidden_states = self.embeddings(input_features) + + encoder_outputs = self.encoder( + inputs_embeds=hidden_states, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + last_hidden_state = encoder_outputs[0] + pooled_output = last_hidden_state[:, 0, :] + pooled_output = self.layernorm(pooled_output) + + if not return_dict: + return (last_hidden_state, pooled_output) + encoder_outputs[1:] + + return ImageBindTransformerOutput( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +@add_start_docstrings( + """The vision model from ImageBind without any head or projection on top.""", + IMAGEBIND_START_DOCSTRING, +) +class ImageBindAudioModel(ImageBindPreTrainedModel): + config_class = ImageBindAudioConfig + _no_split_modules = 
["ImageBindEncoderLayer"] + + main_input_name = "input_features" + + def __init__(self, config: ImageBindAudioConfig): + super().__init__(config) + self.audio_model = ImageBindAudioTransformer(config) + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> nn.Module: + return self.audio_model.embeddings.patch_embedding + + @add_start_docstrings_to_model_forward(IMAGEBIND_AUDIO_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=ImageBindTransformerOutput, config_class=ImageBindAudioConfig) + def forward( + self, + input_features: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, ImageBindTransformerOutput]: + r""" + Returns: + + Examples: + + ```python + >>> import torchaudio + >>> from datasets import load_dataset + >>> from transformers import AutoProcessor, ImageBindAudioModel + + >>> ds = load_dataset("EduardoPacheco/imagebind-example-data", split="train") + >>> audio = ds[0]["audio"] + >>> audio = torchaudio.functional.resample(torch.from_numpy(audio["array"]), orig_freq=audio["sampling_rate"], new_freq=16000).numpy() + + >>> model = ImageBindAudioModel.from_pretrained("EduardoPacheco/imagebind-huge") + >>> processor = AutoProcessor.from_pretrained("EduardoPacheco/imagebind-huge") + + >>> inputs = processor(audios=audio, return_tensors="pt") + + >>> outputs = model(**inputs) + >>> last_hidden_state = outputs.last_hidden_state + >>> pooled_output = outputs.pooler_output # pooled CLS states + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + return self.audio_model( + input_features=input_features, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + +@add_start_docstrings(IMAGEBIND_START_DOCSTRING) +class ImageBindModel(ImageBindPreTrainedModel): + 
config_class = ImageBindConfig + main_input_name = "pixel_values" + + def __init__(self, config: ImageBindConfig): + super().__init__(config) + + if not isinstance(config.text_config, ImageBindTextConfig): + raise ValueError( + "config.text_config is expected to be of type ImageBindTextConfig but is of type" + f" {type(config.text_config)}." + ) + + if not isinstance(config.vision_config, ImageBindVisionConfig): + raise ValueError( + "config.vision_config is expected to be of type ImageBindVisionConfig but is of type" + f" {type(config.vision_config)}." + ) + + if not isinstance(config.audio_config, ImageBindAudioConfig): + raise ValueError( + "config.audio_config is expected to be of type ImageBindAudioConfig but is of type" + f" {type(config.audio_config)}." + ) + + text_config = config.text_config + vision_config = config.vision_config + audio_config = config.audio_config + + self.projection_dim = config.projection_dim + self.text_embed_dim = text_config.hidden_size + self.vision_embed_dim = vision_config.hidden_size + self.audio_embed_dim = audio_config.hidden_size + + self.text_model = ImageBindTextTransformer(text_config) + self.vision_model = ImageBindVisionTransformer(vision_config) + self.audio_model = ImageBindAudioTransformer(audio_config) + + self.text_projection = nn.Linear(self.text_embed_dim, self.projection_dim, bias=False) + self.vision_projection = nn.Linear(self.vision_embed_dim, self.projection_dim, bias=False) + self.audio_projection = nn.Linear(self.audio_embed_dim, self.projection_dim, bias=False) + + self.text_postprocessor = ImageBindPostProcessor(text_config) + self.vision_postprocessor = ImageBindPostProcessor(vision_config) + self.audio_postprocessor = ImageBindPostProcessor(audio_config) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(IMAGEBIND_TEXT_INPUTS_DOCSTRING) + def get_text_features( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: 
Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> torch.FloatTensor: + r""" + Returns: + text_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The text embeddings obtained by + applying the projection layer to the pooled output of [`ImageBindTextModel`]. + + Examples: + + ```python + >>> from transformers import AutoTokenizer, ImageBindModel + + >>> model = ImageBindModel.from_pretrained("EduardoPacheco/imagebind-huge") + >>> tokenizer = AutoTokenizer.from_pretrained("EduardoPacheco/imagebind-huge") + + >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt") + >>> text_features = model.get_text_features(**inputs) + ```""" + # Use ImageBind model's config for some fields (if specified) instead of those in the text component. + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + text_outputs = self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = text_outputs[1] + text_features = self.text_projection(pooled_output) + + return text_features + + @add_start_docstrings_to_model_forward(IMAGEBIND_VISION_INPUTS_DOCSTRING) + def get_image_features( + self, + pixel_values: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> torch.FloatTensor: + r""" + Returns: + image_features 
(`torch.FloatTensor` of shape `(batch_size, output_dim`): The image embeddings obtained by + applying the projection layer to the pooled output of [`ImageBindVisionModel`]. + + Examples: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, ImageBindModel + + >>> model = ImageBindModel.from_pretrained("EduardoPacheco/imagebind-huge") + >>> processor = AutoProcessor.from_pretrained("EduardoPacheco/imagebind-huge") + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = processor(images=image, return_tensors="pt") + + >>> image_features = model.get_image_features(**inputs) + ```""" + # Use ImageBind model's config for some fields (if specified) instead of those in the vision components. + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + batch_size = pixel_values.shape[0] + + vision_outputs = self.vision_model( + pixel_values=pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = vision_outputs[1] # pooled_output + image_features = self.vision_projection(pooled_output) + + if pixel_values.ndim >= 5: + num_clips = pixel_values.shape[1] + image_features = image_features.reshape(batch_size, num_clips, -1) + # Take mean over all clips + image_features = image_features.mean(dim=1) + + return image_features + + @add_start_docstrings_to_model_forward(IMAGEBIND_AUDIO_INPUTS_DOCSTRING) + def get_audio_features( + self, + input_features: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + 
        return_dict: Optional[bool] = None,
    ) -> torch.FloatTensor:
        r"""
        Returns:
            audio_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The audio embeddings obtained by
            applying the projection layer to the pooled output of [`ImageBindAudioModel`].

        Examples:

        ```python
        >>> import torch
        >>> import torchaudio
        >>> from datasets import load_dataset
        >>> from transformers import AutoProcessor, ImageBindModel

        >>> ds = load_dataset("EduardoPacheco/imagebind-example-data", split="train")
        >>> audio = ds[0]["audio"]
        >>> audio = torchaudio.functional.resample(torch.from_numpy(audio["array"]), orig_freq=audio["sampling_rate"], new_freq=16000).numpy()

        >>> model = ImageBindModel.from_pretrained("EduardoPacheco/imagebind-huge")
        >>> processor = AutoProcessor.from_pretrained("EduardoPacheco/imagebind-huge")

        >>> inputs = processor(audios=audio, return_tensors="pt")

        >>> audio_features = model.get_audio_features(**inputs)
        ```"""
        # Use ImageBind model's config for some fields (if specified) instead of those in the audio component.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # NOTE(review): dereferences input_features before the None check inside audio_model —
        # a missing input raises AttributeError instead of a clear ValueError; confirm intended.
        batch_size = input_features.shape[0]

        audio_outputs = self.audio_model(
            input_features=input_features,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        pooled_output = audio_outputs[1]  # pooled_output
        audio_features = self.audio_projection(pooled_output)

        # If audio is chunked (i.e. same audio is split into multiple clips), reduce embedding over clips dimension
        # See https://github.com/facebookresearch/ImageBind/blob/main/imagebind/models/imagebind_model.py#L470
        if input_features.ndim >= 5:
            num_clips = input_features.shape[1]
            audio_features = audio_features.reshape(batch_size, num_clips, -1)
            # Take mean over all clips
            audio_features = audio_features.mean(dim=1)

        return audio_features

    @add_start_docstrings_to_model_forward(IMAGEBIND_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=ImageBindOutput, config_class=ImageBindConfig)
    def forward(
        self,
        pixel_values: torch.FloatTensor,
        input_features: Optional[torch.Tensor] = None,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        return_loss: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, ImageBindOutput]:
        r"""
        Returns:

        Examples:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, ImageBindModel

        >>> model = ImageBindModel.from_pretrained("EduardoPacheco/imagebind-huge")
        >>> processor = AutoProcessor.from_pretrained("EduardoPacheco/imagebind-huge")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = processor(
        ...     text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="pt", padding=True
        ... )

        >>> outputs = model(**inputs)
        >>> logits_per_image = outputs.logits_per_image  # this is the image-text similarity score
        >>> probs = logits_per_image.softmax(dim=1)  # we can take the softmax to get the label probabilities
        ```"""
        # We expect a combination of pixel_values and one of the other inputs i.e. input_features or input_ids should be provided
        if input_ids is None and input_features is None:
            raise ValueError("At least one of `input_ids` or `input_features` should be provided.")

        # We expect only one of input_features or input_ids to be provided
        if input_ids is not None and input_features is not None:
            raise ValueError("Only one of `input_ids` or `input_features` should be provided.")

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # running the vision model (always required — vision is the "binding" modality)
        image_batch_size = pixel_values.shape[0]

        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        image_embeds = vision_outputs[1]
        image_embeds = self.vision_projection(image_embeds)
        image_embeds = self.vision_postprocessor(image_embeds)

        # For video inputs we take multiple clips and average the embeddings
        # See https://github.com/facebookresearch/ImageBind/blob/main/imagebind/models/imagebind_model.py#L470
        if pixel_values.ndim >= 5:
            image_num_clips = pixel_values.shape[1]
            image_embeds = image_embeds.reshape(image_batch_size, image_num_clips, -1)
            image_embeds = image_embeds.mean(dim=1)

        # running the text model if input_ids is provided
        text_embeds = None
        logits_per_text = None
        text_outputs = None
        if input_ids is not None:
            text_outputs = self.text_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                position_ids=position_ids,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
            text_embeds = text_outputs[1]
            text_embeds = self.text_projection(text_embeds)
            text_embeds = self.text_postprocessor(text_embeds)

            logits_per_text = torch.matmul(text_embeds, image_embeds.t())
            logits_per_image = logits_per_text.t()

        # running the audio model if input_features is provided
        audio_embeds = None
        logits_per_audio = None
        audio_outputs = None
        if input_features is not None:
            audio_batch_size = input_features.shape[0]
            audio_outputs = self.audio_model(
                input_features,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
            audio_embeds = audio_outputs[1]
            audio_embeds = self.audio_projection(audio_embeds)
            audio_embeds = self.audio_postprocessor(audio_embeds)

            if input_features.ndim >= 5:
                num_clips = input_features.shape[1]
                audio_embeds = audio_embeds.reshape(audio_batch_size, num_clips, -1)
                audio_embeds = audio_embeds.mean(dim=1)

            logits_per_audio = torch.matmul(audio_embeds, image_embeds.t())
            # Exactly one of the two branches above runs, so logits_per_image is always defined.
            logits_per_image = logits_per_audio.t()

        loss = None
        if return_loss:
            # Contrastive loss over whichever modality pair was provided.
            loss = imagebind_loss(logits_per_text) if logits_per_text is not None else imagebind_loss(logits_per_audio)

        if not return_dict:
            # Tuple output drops the None entries of the unused modality.
            output = (
                logits_per_image,
                logits_per_text,
                logits_per_audio,
                image_embeds,
                text_embeds,
                audio_embeds,
                vision_outputs,
                text_outputs,
                audio_outputs,
            )
            output = tuple([out for out in output if out is not None])
            return ((loss,) + output) if loss is not None else output

        return ImageBindOutput(
            loss=loss,
            logits_per_image=logits_per_image,
            logits_per_text=logits_per_text,
            logits_per_audio=logits_per_audio,
            image_embeds=image_embeds,
            text_embeds=text_embeds,
            audio_embeds=audio_embeds,
            vision_model_output=vision_outputs,
            text_model_output=text_outputs,
            audio_model_output=audio_outputs,
        )


@add_start_docstrings(
    """
    ImageBind Text Model with a projection layer on top (a linear layer on top of the pooled output).
+ """, + IMAGEBIND_START_DOCSTRING, +) +class ImageBindTextModelWithProjection(ImageBindPreTrainedModel): + config_class = ImageBindTextConfig + + _no_split_modules = ["ImageBindEncoderLayer"] + + def __init__(self, config: ImageBindTextConfig): + super().__init__(config) + + self.text_model = ImageBindTextTransformer(config) + + self.text_projection = nn.Linear(config.hidden_size, config.projection_dim, bias=False) + + self.text_postprocessor = ImageBindPostProcessor(config) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> nn.Module: + return self.text_model.embeddings.token_embedding + + def set_input_embeddings(self, value): + self.text_model.embeddings.token_embedding = value + + @add_start_docstrings_to_model_forward(IMAGEBIND_TEXT_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=ImageBindTextModelOutput, config_class=ImageBindTextConfig) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, ImageBindTextModelOutput]: + r""" + Returns: + + Examples: + + ```python + >>> from transformers import AutoTokenizer, ImageBindTextModelWithProjection + + >>> model = ImageBindTextModelWithProjection.from_pretrained("EduardoPacheco/imagebind-huge") + >>> tokenizer = AutoTokenizer.from_pretrained("EduardoPacheco/imagebind-huge") + + >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt") + + >>> outputs = model(**inputs) + >>> text_embeds = outputs.text_embeds + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + text_outputs = self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + 
output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = text_outputs[1] + + text_embeds = self.text_projection(pooled_output) + normalized_text_embeds = self.text_postprocessor(text_embeds) + + if not return_dict: + outputs = (text_embeds, text_outputs[0]) + text_outputs[2:] + (normalized_text_embeds,) + return tuple(output for output in outputs if output is not None) + + return ImageBindTextModelOutput( + text_embeds=text_embeds, + last_hidden_state=text_outputs.last_hidden_state, + hidden_states=text_outputs.hidden_states, + attentions=text_outputs.attentions, + normalized_text_embeds=normalized_text_embeds, + ) + + +@add_start_docstrings( + """ + ImageBind Vision Model with a projection layer on top (a linear layer on top of the pooled output). + """, + IMAGEBIND_START_DOCSTRING, +) +class ImageBindVisionModelWithProjection(ImageBindPreTrainedModel): + config_class = ImageBindVisionConfig + main_input_name = "pixel_values" + + def __init__(self, config: ImageBindVisionConfig): + super().__init__(config) + + self.vision_model = ImageBindVisionTransformer(config) + + self.vision_projection = nn.Linear(config.hidden_size, config.projection_dim, bias=False) + + self.vision_postprocessor = ImageBindPostProcessor(config) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> nn.Module: + return self.vision_model.embeddings.patch_embedding + + @add_start_docstrings_to_model_forward(IMAGEBIND_VISION_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=ImageBindVisionModelOutput, config_class=ImageBindVisionConfig) + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, ImageBindVisionModelOutput]: + r""" + Returns: + + Examples: + + ```python + >>> from PIL import Image 
+ >>> import requests + >>> from transformers import AutoProcessor, ImageBindVisionModelWithProjection + + >>> model = ImageBindVisionModelWithProjection.from_pretrained("EduardoPacheco/imagebind-huge") + >>> processor = AutoProcessor.from_pretrained("EduardoPacheco/imagebind-huge") + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = processor(images=image, return_tensors="pt") + + >>> outputs = model(**inputs) + >>> image_embeds = outputs.image_embeds + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + batch_size = pixel_values.shape[0] + + vision_outputs = self.vision_model( + pixel_values=pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = vision_outputs[1] # pooled_output + + image_embeds = self.vision_projection(pooled_output) + normalized_image_embeds = self.vision_postprocessor(image_embeds) + + if pixel_values.ndim >= 5: + num_clips = pixel_values.shape[1] + image_embeds = image_embeds.reshape(batch_size, num_clips, -1) + # Take mean over all clips + image_embeds = image_embeds.mean(dim=1) + + normalized_image_embeds = normalized_image_embeds.reshape(batch_size, num_clips, -1) + normalized_image_embeds = normalized_image_embeds.mean(dim=1) + + if not return_dict: + outputs = (image_embeds, vision_outputs[0]) + vision_outputs[2:] + (normalized_image_embeds,) + return tuple(output for output in outputs if output is not None) + + return ImageBindVisionModelOutput( + image_embeds=image_embeds, + last_hidden_state=vision_outputs.last_hidden_state, + hidden_states=vision_outputs.hidden_states, + attentions=vision_outputs.attentions, + normalized_image_embeds=normalized_image_embeds, + ) + + +@add_start_docstrings( + """ + ImageBind Audio Model with a projection layer on top (a linear layer on top of the pooled output). 
+ """, + IMAGEBIND_START_DOCSTRING, +) +class ImageBindAudioModelWithProjection(ImageBindPreTrainedModel): + config_class = ImageBindAudioConfig + main_input_name = "input_features" + + def __init__(self, config: ImageBindAudioConfig): + super().__init__(config) + + self.audio_model = ImageBindAudioTransformer(config) + + self.audio_projection = nn.Linear(config.hidden_size, config.projection_dim, bias=False) + + self.audio_postprocessor = ImageBindPostProcessor(config) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> nn.Module: + return self.audio_model.embeddings.patch_embedding + + @add_start_docstrings_to_model_forward(IMAGEBIND_AUDIO_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=ImageBindAudioModelOutput, config_class=ImageBindAudioConfig) + def forward( + self, + input_features: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, ImageBindAudioModelOutput]: + r""" + Returns: + + Examples: + + ```python + >>> import torch + >>> import torchaudio + >>> from datasets import load_dataset + >>> from transformers import AutoProcessor, ImageBindAudioModelWithProjection + + >>> ds = load_dataset("EduardoPacheco/imagebind-example-data", split="train") + >>> audio = ds[0]["audio"] + >>> audio = torchaudio.functional.resample(torch.from_numpy(audio["array"]), orig_freq=audio["sampling_rate"], new_freq=16000).numpy() + + >>> model = ImageBindAudioModelWithProjection.from_pretrained("EduardoPacheco/imagebind-huge") + >>> processor = AutoProcessor.from_pretrained("EduardoPacheco/imagebind-huge") + + >>> inputs = processor(audios=audio, return_tensors="pt") + + >>> outputs = model(**inputs) + >>> audio_embeds = outputs.audio_embeds + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + batch_size = input_features.shape[0] + + 
audio_outputs = self.audio_model( + input_features=input_features, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = audio_outputs[1] # pooled_output + + audio_embeds = self.audio_projection(pooled_output) + normalized_audio_embeds = self.audio_postprocessor(audio_embeds) + + if input_features.ndim >= 5: + num_clips = input_features.shape[1] + audio_embeds = audio_embeds.reshape(batch_size, num_clips, -1) + # Take mean over all clips + audio_embeds = audio_embeds.mean(dim=1) + + normalized_audio_embeds = normalized_audio_embeds.reshape(batch_size, num_clips, -1) + normalized_audio_embeds = normalized_audio_embeds.mean(dim=1) + + if not return_dict: + outputs = (audio_embeds, audio_outputs[0]) + audio_outputs[2:] + (normalized_audio_embeds,) + return tuple(output for output in outputs if output is not None) + + return ImageBindAudioModelOutput( + audio_embeds=audio_embeds, + last_hidden_state=audio_outputs.last_hidden_state, + hidden_states=audio_outputs.hidden_states, + attentions=audio_outputs.attentions, + normalized_audio_embeds=normalized_audio_embeds, + ) + + +__all__ = [ + "ImageBindTextModel", + "ImageBindVisionModel", + "ImageBindAudioModel", + "ImageBindPreTrainedModel", + "ImageBindModel", + "ImageBindTextModelWithProjection", + "ImageBindVisionModelWithProjection", + "ImageBindAudioModelWithProjection", +] diff --git a/src/transformers/models/imagebind/processing_imagebind.py b/src/transformers/models/imagebind/processing_imagebind.py new file mode 100644 index 000000000000..d507afbbbeb9 --- /dev/null +++ b/src/transformers/models/imagebind/processing_imagebind.py @@ -0,0 +1,194 @@ +# Copyright 2024 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Image/Text processor class for ImageBind
+"""
+
+from typing import List, Optional, Union
+
+import numpy as np
+
+from ...image_utils import (
+    ImageInput,
+    VideoInput,
+)
+from ...processing_utils import AudioKwargs, ImagesKwargs, ProcessingKwargs, ProcessorMixin, Unpack
+from ...tokenization_utils_base import (
+    BatchEncoding,
+    PreTokenizedInput,
+    TextInput,
+)
+
+
+class ImageBindProcessorImagesKwargs(ImagesKwargs, total=False):
+    do_convert_rgb: bool
+    do_chunk: bool
+    chunk_duration: float
+    num_chunks: int
+    num_frames_per_chunk: int
+    fps: int
+
+
+class ImageBindProcessorAudioKwargs(AudioKwargs, total=False):
+    do_normalize: Optional[bool]
+    mean: Optional[float]
+    std: Optional[float]
+    do_chunk: Optional[bool]
+    chunk_duration: Optional[float]
+    num_chunks: Optional[int]
+
+
+class ImageBindProcessorKwargs(ProcessingKwargs, total=False):
+    # see processing_utils.ProcessingKwargs documentation for usage.
+    images_kwargs: ImageBindProcessorImagesKwargs
+    audio_kwargs: ImageBindProcessorAudioKwargs
+    _defaults = {
+        "images_kwargs": {
+            "do_convert_rgb": True,
+            "do_chunk": True,
+            "chunk_duration": 2.0,
+            "num_chunks": 5,
+            "num_frames_per_chunk": 2,
+            "fps": 30,
+        },
+        "audio_kwargs": {
+            "sampling_rate": 16000,
+            "do_normalize": True,
+            "mean": -4.268,
+            "std": 9.138,
+            "do_chunk": True,
+            "chunk_duration": 2.0,
+            "num_chunks": 3,
+        },
+    }
+
+
+class ImageBindProcessor(ProcessorMixin):
+    r"""
+    Constructs an ImageBind processor which wraps an ImageBind image processor, a feature extractor and a CLIP tokenizer into a single processor.
+
+    [`ImageBindProcessor`] offers all the functionalities of [`ImageBindImageProcessor`], [`ImageBindFeatureExtractor`] and [`CLIPTokenizerFast`].
+    See the [`~ImageBindProcessor.__call__`] and [`~ImageBindProcessor.decode`] for more information.
+
+    Args:
+        image_processor ([`ImageBindImageProcessor`]):
+            An instance of [`ImageBindImageProcessor`] to process the images. This is a required input.
+        tokenizer ([`CLIPTokenizer`, `CLIPTokenizerFast`]):
+            An instance of [`PreTrainedTokenizer`] or [`PreTrainedTokenizerFast`]. The tokenizer is a required input.
+        feature_extractor ([`ImageBindFeatureExtractor`]):
+            An instance of [`ImageBindFeatureExtractor`] to extract features from the audio. This is a required input.
+    """
+
+    attributes = ["image_processor", "tokenizer", "feature_extractor"]
+    image_processor_class = "ImageBindImageProcessor"
+    feature_extractor_class = "ImageBindFeatureExtractor"
+    tokenizer_class = ("CLIPTokenizer", "CLIPTokenizerFast")
+
+    def __init__(self, image_processor, tokenizer, feature_extractor):
+        super().__init__(image_processor, tokenizer, feature_extractor)
+
+    def __call__(
+        self,
+        images: Optional[ImageInput] = None,
+        videos: Optional[VideoInput] = None,
+        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
+        audio: Union[np.ndarray, List[float], List[np.ndarray], List[List[float]]] = None,
+        **kwargs: Unpack[ImageBindProcessorKwargs],
+    ) -> BatchEncoding:
+        """
+        Main method to prepare for the model one or several sequence(s) and image(s). This method forwards the `text`
+        and `kwargs` arguments to CLIPTokenizerFast's [`~CLIPTokenizerFast.__call__`] if `text` is not `None` to encode
+        the text. To prepare the image(s), this method forwards the `images` and `kwargs` arguments to
+        ImageBindImageProcessor's [`~ImageBindImageProcessor.__call__`] if `images` is not `None`. Please refer to the docstring
+        of the above two methods for more information.
+ Args: + images (`ImageInput`, *optional*): + The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch + tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is + number of channels, H and W are image's height and width. + videos (`VideoInput`, *optional*): + Video frames to preprocess. Expects a single or batch of video frames in PIL images, NumPy array, PyTorch + tensor or Lists. Each video should be of shape (T, C, H, W), where T is number of frames, C is + number of channels, H and W are image height and width. + text (`TextInput`, `PreTokenizedInput`, `List[TextInput]`, `List[PreTokenizedInput]`, *optional*): + The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings + (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set + `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). + audio (`AudioInput`, `List[float]`, `List[List[float]]`, `List[List[List[float]]]`): + The sequence or batch of sequences to be padded. Each sequence can be a numpy array, a list of numpy + arrays or a (possibly nested) list of float values. The supported input types are as follows: + + - unbatched: `List[float]`, `np.ndarray` (`ndim=1`) + - batched: `List[List[float]]`, `List[np.ndarray]` (`ndim=1`), `np.ndarray` (`ndim=2`) + - batched with clips: `List[List[List[float]]]`, `List[List[np.ndarray]]` (`ndim=1`), `List[np.ndarray]` (`ndim=2`), np.ndarray (`ndim=3`) + + The input will always be interpreted as mono channel audio, not stereo, i.e. a single float per timestep. + Returns: + [`BatchEncoding`]: A [`BatchEncoding`] with the following fields: + - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`. 
+ - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when + `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not + `None`). + - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` or `videos` is not `None`. + - **input_features** -- List of input features to be fed to a model. Returned when `audio` is not `None`. + """ + + if text is None and images is None and videos is None and audio is None: + raise ValueError("You have to specify either text, images, videos or audio. All cannot be none.") + + output_kwargs = self._merge_kwargs( + ImageBindProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) + + data = {} + + if text is not None: + encoding = self.tokenizer(text, **output_kwargs["text_kwargs"]) + data.update(encoding) + + if images is not None or videos is not None: + image_features = self.image_processor(images=images, videos=videos, **output_kwargs["images_kwargs"]) + data.update(image_features) + + if audio is not None: + audio_features = self.feature_extractor(audio, **output_kwargs["audio_kwargs"]) + data.update(audio_features) + + return BatchEncoding(data=data, tensor_type=output_kwargs["common_kwargs"].get("return_tensors")) + + def batch_decode(self, *args, **kwargs): + """ + This method forwards all its arguments to ImageBindTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please + refer to the docstring of this method for more information. + """ + return self.tokenizer.batch_decode(*args, **kwargs) + + def decode(self, *args, **kwargs): + """ + This method forwards all its arguments to ImageBindTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to + the docstring of this method for more information. 
+ """ + return self.tokenizer.decode(*args, **kwargs) + + @property + def model_input_names(self): + tokenizer_input_names = self.tokenizer.model_input_names + image_processor_input_names = self.image_processor.model_input_names + feature_extractor_input_names = self.feature_extractor.model_input_names + return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names + feature_extractor_input_names)) + + +__all__ = ["ImageBindProcessor"] diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index d7570c57c62f..873b8dd0bb48 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -4908,6 +4908,62 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) +class ImageBindAudioModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class ImageBindAudioModelWithProjection(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class ImageBindModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class ImageBindPreTrainedModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class ImageBindTextModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class ImageBindTextModelWithProjection(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class ImageBindVisionModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class ImageBindVisionModelWithProjection(metaclass=DummyObject): + _backends = 
["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class ImageGPTForCausalImageModeling(metaclass=DummyObject): _backends = ["torch"] diff --git a/src/transformers/utils/dummy_vision_objects.py b/src/transformers/utils/dummy_vision_objects.py index d2ccaeaaed23..e2d53455a301 100644 --- a/src/transformers/utils/dummy_vision_objects.py +++ b/src/transformers/utils/dummy_vision_objects.py @@ -303,6 +303,13 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["vision"]) +class ImageBindImageProcessor(metaclass=DummyObject): + _backends = ["vision"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["vision"]) + + class ImageGPTFeatureExtractor(metaclass=DummyObject): _backends = ["vision"] diff --git a/tests/models/imagebind/__init__.py b/tests/models/imagebind/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/models/imagebind/test_feature_extraction_imagebind.py b/tests/models/imagebind/test_feature_extraction_imagebind.py new file mode 100644 index 000000000000..4092978f228d --- /dev/null +++ b/tests/models/imagebind/test_feature_extraction_imagebind.py @@ -0,0 +1,234 @@ +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +import itertools +import os +import random +import tempfile +import unittest + +import numpy as np + +from transformers import ImageBindFeatureExtractor +from transformers.testing_utils import check_json_file_has_correct_format, require_torch, require_torchaudio +from transformers.utils.import_utils import is_speech_available, is_torch_available + +from ...test_sequence_feature_extraction_common import SequenceFeatureExtractionTestMixin + + +global_rng = random.Random() + +if is_torch_available(): + import torch + +if is_speech_available(): + import torchaudio + + +# Copied from tests.models.whisper.test_feature_extraction_whisper.floats_list +def floats_list(shape, scale=1.0, rng=None, name=None): + """Creates a random float32 tensor""" + if rng is None: + rng = global_rng + + values = [] + for batch_idx in range(shape[0]): + values.append([]) + for _ in range(shape[1]): + values[-1].append(rng.random() * scale) + + return values + + +class ImageBindFeatureExtractionTester(unittest.TestCase): + def __init__( + self, + parent, + batch_size=7, + min_seq_length=400, + max_seq_length=2000, + feature_size=1, + padding_value=0.0, + sampling_rate=16000, + do_normalize=True, + return_attention_mask=False, + ): + self.parent = parent + self.batch_size = batch_size + self.min_seq_length = min_seq_length + self.max_seq_length = max_seq_length + self.seq_length_diff = (self.max_seq_length - self.min_seq_length) // (self.batch_size - 1) + self.feature_size = feature_size + self.padding_value = padding_value + self.sampling_rate = sampling_rate + self.return_attention_mask = return_attention_mask + self.do_normalize = do_normalize + + def prepare_feat_extract_dict(self): + return { + "feature_size": self.feature_size, + "padding_value": self.padding_value, + "sampling_rate": self.sampling_rate, + "return_attention_mask": self.return_attention_mask, + "do_normalize": self.do_normalize, + } + + def prepare_inputs_for_common(self, equal_length=False, numpify=False): + def 
_flatten(list_of_lists): + return list(itertools.chain(*list_of_lists)) + + if equal_length: + speech_inputs = floats_list((self.batch_size, self.max_seq_length)) + else: + # make sure that inputs increase in size + speech_inputs = [ + _flatten(floats_list((x, self.feature_size))) + for x in range(self.min_seq_length, self.max_seq_length, self.seq_length_diff) + ] + + if numpify: + speech_inputs = [np.asarray(x) for x in speech_inputs] + + return speech_inputs + + +@require_torch +@require_torchaudio +class ImageBindFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.TestCase): + feature_extraction_class = ImageBindFeatureExtractor + + def setUp(self): + self.feat_extract_tester = ImageBindFeatureExtractionTester(self) + + def test_call(self): + # Tests that all call wrap to encode_plus and batch_encode_plus + feat_extract = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict()) + # create three inputs of length 800, 1000, and 1200 + speech_inputs = [floats_list((1, x))[0] for x in range(800, 1400, 200)] + np_speech_inputs = [np.asarray(speech_input) for speech_input in speech_inputs] + + # Test not batched input + encoded_sequences_1 = feat_extract(speech_inputs[0], return_tensors="np").input_features + encoded_sequences_2 = feat_extract(np_speech_inputs[0], return_tensors="np").input_features + self.assertTrue(np.allclose(encoded_sequences_1, encoded_sequences_2, atol=1e-3)) + + # Test batched + encoded_sequences_1 = feat_extract(speech_inputs, return_tensors="np").input_features + encoded_sequences_2 = feat_extract(np_speech_inputs, return_tensors="np").input_features + self.assertTrue(np.allclose(encoded_sequences_1, encoded_sequences_2, atol=1e-3)) + + # Test 2-D numpy arrays are batched. 
+ speech_inputs = [floats_list((1, x))[0] for x in (800, 800, 800)] + np_speech_inputs = np.asarray(speech_inputs) + encoded_sequences_1 = feat_extract(speech_inputs, return_tensors="np").input_features + encoded_sequences_2 = feat_extract(np_speech_inputs, return_tensors="np").input_features + self.assertTrue(np.allclose(encoded_sequences_1, encoded_sequences_2, atol=1e-3)) + + # Test 3-D numpy arrays are batched and chunked. + speech_inputs = [[floats_list((1, x))[0]] for x in (800, 800, 800)] + np_speech_inputs = np.asarray(speech_inputs) + encoded_sequences_1 = feat_extract(speech_inputs, return_tensors="np").input_features + encoded_sequences_2 = feat_extract(np_speech_inputs, return_tensors="np").input_features + self.assertTrue(np.allclose(encoded_sequences_1, encoded_sequences_2, atol=1e-3)) + + def _load_datasamples(self): + from datasets import load_dataset + + ds = load_dataset("EduardoPacheco/imagebind-example-data", split="train") + audios = [ + torchaudio.functional.resample( + torch.from_numpy(audio["array"]), + orig_freq=audio["sampling_rate"], + new_freq=self.feat_extract_tester.sampling_rate, + ).numpy() + for audio in ds["audio"] + ] + + return audios + + @require_torch + def test_integration(self): + # fmt: off + expected_input1 = torch.tensor( + [[-1.2776, -0.9167, -1.2776], + [-1.2439, -0.8372, -0.8748], + [-1.1235, -0.7492, -1.0867]] + ) + expected_input2 = torch.tensor( + [[-1.1474, -0.5601, -0.1045], + [0.0730, 0.0503, 0.0564], + [-0.1738, 0.0505, -0.2641]] + ) + # fmt: on + + input_speech = self._load_datasamples() + feature_extractor = ImageBindFeatureExtractor() + input_values = feature_extractor(input_speech, return_tensors="pt").input_features + expected_shape = ( + len(input_speech), + feature_extractor.num_chunks, + 1, + feature_extractor.num_mel_bins, + feature_extractor.max_length, + ) + self.assertEqual(input_values.shape, expected_shape) + self.assertTrue(torch.allclose(input_values[:, :, 0, 0, 0], expected_input1, atol=1e-4)) + 
self.assertTrue(torch.allclose(input_values[:, :, 0, 111, 0], expected_input2, atol=1e-4)) + + def test_feat_extract_from_and_save_pretrained(self): + feat_extract_first = self.feature_extraction_class(**self.feat_extract_dict) + + with tempfile.TemporaryDirectory() as tmpdirname: + saved_file = feat_extract_first.save_pretrained(tmpdirname)[0] + check_json_file_has_correct_format(saved_file) + feat_extract_second = self.feature_extraction_class.from_pretrained(tmpdirname) + + dict_first = feat_extract_first.to_dict() + dict_second = feat_extract_second.to_dict() + self.assertDictEqual(dict_first, dict_second) + + def test_feat_extract_to_json_file(self): + feat_extract_first = self.feature_extraction_class(**self.feat_extract_dict) + + with tempfile.TemporaryDirectory() as tmpdirname: + json_file_path = os.path.join(tmpdirname, "feat_extract.json") + feat_extract_first.to_json_file(json_file_path) + feat_extract_second = self.feature_extraction_class.from_json_file(json_file_path) + + dict_first = feat_extract_first.to_dict() + dict_second = feat_extract_second.to_dict() + self.assertEqual(dict_first, dict_second) + + +# exact same tests than before, except that we simulate that torchaudio is not available +@require_torch +@unittest.mock.patch( + "transformers.models.imagebind.feature_extraction_imagebind.is_speech_available", + lambda: False, +) +class ImageBindFeatureExtractionWithoutTorchaudioTest(ImageBindFeatureExtractionTest): + def test_using_audio_utils(self): + # Tests that it uses audio_utils instead of torchaudio + feat_extract = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict()) + + self.assertTrue(hasattr(feat_extract, "window")) + self.assertTrue(hasattr(feat_extract, "mel_filters")) + + from transformers.models.imagebind.feature_extraction_imagebind import ( + is_speech_available, + ) + + self.assertFalse(is_speech_available()) diff --git a/tests/models/imagebind/test_image_processing_imagebind.py 
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the ImageBind image processor (image and video inputs)."""

import unittest

import numpy as np

from transformers.image_utils import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
from transformers.testing_utils import require_torch, require_vision
from transformers.utils import is_torch_available, is_vision_available

from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs


if is_torch_available():
    import torch

if is_vision_available():
    from PIL import Image

    from transformers import ImageBindImageProcessor


class ImageBindImageProcessingTester(unittest.TestCase):
    """Builds image-processor kwargs and random image/video inputs for the tests below."""

    def __init__(
        self,
        parent,
        batch_size=5,
        num_channels=3,
        image_size=18,
        min_resolution=30,
        max_resolution=80,
        do_resize=True,
        size=None,
        do_center_crop=True,
        crop_size=None,
        do_normalize=True,
        image_mean=OPENAI_CLIP_MEAN,
        image_std=OPENAI_CLIP_STD,
        do_convert_rgb=True,
    ):
        size = size if size is not None else {"shortest_edge": 20}
        crop_size = crop_size if crop_size is not None else {"height": 18, "width": 18}
        self.parent = parent
        self.batch_size = batch_size
        self.num_channels = num_channels
        self.image_size = image_size
        self.min_resolution = min_resolution
        self.max_resolution = max_resolution
        self.do_resize = do_resize
        self.size = size
        self.do_center_crop = do_center_crop
        self.crop_size = crop_size
        self.do_normalize = do_normalize
        self.image_mean = image_mean
        self.image_std = image_std
        self.do_convert_rgb = do_convert_rgb

    def prepare_image_processor_dict(self):
        """Return the kwargs used to instantiate the image processor under test."""
        return {
            "do_resize": self.do_resize,
            "size": self.size,
            "do_center_crop": self.do_center_crop,
            "crop_size": self.crop_size,
            "do_normalize": self.do_normalize,
            "image_mean": self.image_mean,
            "image_std": self.image_std,
            "do_convert_rgb": self.do_convert_rgb,
        }

    # Copied from tests.models.clip.test_image_processing_clip.CLIPImageProcessingTester.expected_output_image_shape
    def expected_output_image_shape(self, images):
        return self.num_channels, self.crop_size["height"], self.crop_size["width"]

    # Copied from tests.models.clip.test_image_processing_clip.CLIPImageProcessingTester.prepare_image_inputs
    def prepare_image_inputs(self, equal_resolution=False, numpify=False, torchify=False):
        return prepare_image_inputs(
            batch_size=self.batch_size,
            num_channels=self.num_channels,
            min_resolution=self.min_resolution,
            max_resolution=self.max_resolution,
            equal_resolution=equal_resolution,
            numpify=numpify,
            torchify=torchify,
        )

    def prepare_video_inputs(self, equal_resolution=False, numpify=False, torchify=False):
        """Build fake video clips by repeating each random frame 8 times."""
        images = prepare_image_inputs(
            batch_size=self.batch_size,
            num_channels=self.num_channels,
            min_resolution=self.min_resolution,
            max_resolution=self.max_resolution,
            equal_resolution=equal_resolution,
            numpify=numpify,
            torchify=torchify,
        )
        # Simply copy the frames to fake a long video clip.
        if numpify or torchify:
            videos = []
            for image in images:
                if numpify:
                    # np.ndarray.repeat(8, axis=0) turns (1, C, H, W) into (8, C, H, W).
                    video = image[None, ...].repeat(8, 0)
                else:
                    # torch.Tensor.repeat tiles each dimension: (1, C, H, W) -> (8, C, H, W).
                    video = image[None, ...].repeat(8, 1, 1, 1)
                videos.append(video)
        else:
            videos = []
            for pil_image in images:
                videos.append([pil_image] * 8)

        return videos


@require_torch
@require_vision
class ImageBindImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
    image_processing_class = ImageBindImageProcessor if is_vision_available() else None

    # Copied from tests.models.clip.test_image_processing_clip.CLIPImageProcessingTest.setUp with CLIP->ImageBind
    def setUp(self):
        super().setUp()
        self.image_processor_tester = ImageBindImageProcessingTester(self)

    @property
    # Copied from tests.models.clip.test_image_processing_clip.CLIPImageProcessingTest.image_processor_dict
    def image_processor_dict(self):
        return self.image_processor_tester.prepare_image_processor_dict()

    def test_image_processor_properties(self):
        """Every configured attribute must exist on the instantiated processor."""
        image_processing = self.image_processing_class(**self.image_processor_dict)
        self.assertTrue(hasattr(image_processing, "do_resize"))
        self.assertTrue(hasattr(image_processing, "size"))
        self.assertTrue(hasattr(image_processing, "do_center_crop"))
        self.assertTrue(hasattr(image_processing, "center_crop"))
        # Previously missing: `crop_size` is part of the config dict and must be set too.
        self.assertTrue(hasattr(image_processing, "crop_size"))
        self.assertTrue(hasattr(image_processing, "do_normalize"))
        self.assertTrue(hasattr(image_processing, "image_mean"))
        self.assertTrue(hasattr(image_processing, "image_std"))
        self.assertTrue(hasattr(image_processing, "do_convert_rgb"))

    # Copied from tests.models.clip.test_image_processing_clip.CLIPImageProcessingTest.test_image_processor_from_dict_with_kwargs
    def test_image_processor_from_dict_with_kwargs(self):
        image_processor = self.image_processing_class.from_dict(self.image_processor_dict)
        self.assertEqual(image_processor.size, {"shortest_edge": 20})
        self.assertEqual(image_processor.crop_size, {"height": 18, "width": 18})

        image_processor = self.image_processing_class.from_dict(self.image_processor_dict, size=42, crop_size=84)
        self.assertEqual(image_processor.size, {"shortest_edge": 42})
        self.assertEqual(image_processor.crop_size, {"height": 84, "width": 84})

    def test_call_pil(self):
        # Initialize image_processing
        image_processing = self.image_processing_class(**self.image_processor_dict)
        # create random PIL images
        image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True)
        for image in image_inputs:
            self.assertIsInstance(image, Image.Image)

        # Test not batched input
        encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values
        expected_output_image_shape = (1, 3, 18, 18)
        self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)

        # Test batched
        encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
        expected_output_image_shape = (5, 3, 18, 18)
        self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)

    def test_call_numpy(self):
        # Initialize image_processing
        image_processing = self.image_processing_class(**self.image_processor_dict)
        # create random numpy tensors
        image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True, numpify=True)
        for image in image_inputs:
            self.assertIsInstance(image, np.ndarray)

        # Test not batched input
        encoded_images = image_processing(images=image_inputs[0], return_tensors="pt").pixel_values
        expected_output_image_shape = (1, 3, 18, 18)
        self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)

        # Test batched
        encoded_images = image_processing(images=image_inputs, return_tensors="pt").pixel_values
        expected_output_image_shape = (5, 3, 18, 18)
        self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)

    def test_call_numpy_videos(self):
        # Initialize image_processing
        image_processing = self.image_processing_class(**self.image_processor_dict)
        # create random numpy tensors
        video_inputs = self.image_processor_tester.prepare_video_inputs(numpify=True, equal_resolution=True)
        for video in video_inputs:
            self.assertIsInstance(video, np.ndarray)

        # Test not batched input
        encoded_videos = image_processing(images=None, videos=video_inputs[0], return_tensors="pt").pixel_values
        expected_output_video_shape = (1, 5, 3, 2, 18, 18)
        self.assertEqual(tuple(encoded_videos.shape), expected_output_video_shape)

        # Test batched
        encoded_videos = image_processing(images=None, videos=video_inputs, return_tensors="pt").pixel_values
        expected_output_video_shape = (5, 5, 3, 2, 18, 18)
        self.assertEqual(tuple(encoded_videos.shape), expected_output_video_shape)

    def test_call_pil_videos(self):
        # Initialize image_processing
        image_processing = self.image_processing_class(**self.image_processor_dict)
        # the inputs come in list of lists batched format
        video_inputs = self.image_processor_tester.prepare_video_inputs(equal_resolution=True)
        for video in video_inputs:
            self.assertIsInstance(video[0], Image.Image)

        # Test not batched input
        encoded_videos = image_processing(images=None, videos=video_inputs[0], return_tensors="pt").pixel_values
        expected_output_video_shape = (1, 5, 3, 2, 18, 18)
        self.assertEqual(tuple(encoded_videos.shape), expected_output_video_shape)

        # Test batched
        encoded_videos = image_processing(images=None, videos=video_inputs, return_tensors="pt").pixel_values
        expected_output_video_shape = (5, 5, 3, 2, 18, 18)
        self.assertEqual(tuple(encoded_videos.shape), expected_output_video_shape)

    def test_call_pytorch(self):
        # Initialize image_processing
        image_processing = self.image_processing_class(**self.image_processor_dict)
        # create random PyTorch tensors
        image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True, torchify=True)

        for image in image_inputs:
            self.assertIsInstance(image, torch.Tensor)

        # Test not batched input
        encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values
        expected_output_image_shape = (1, 3, 18, 18)
        self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)

        # Test batched
        encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
        expected_output_image_shape = (5, 3, 18, 18)
        self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)

    def test_call_pytorch_videos(self):
        # Initialize image_processing
        image_processing = self.image_processing_class(**self.image_processor_dict)
        # create random PyTorch tensors
        video_inputs = self.image_processor_tester.prepare_video_inputs(equal_resolution=True, torchify=True)
        for video in video_inputs:
            self.assertIsInstance(video, torch.Tensor)

        # Test not batched input
        encoded_videos = image_processing(images=None, videos=video_inputs[0], return_tensors="pt").pixel_values
        expected_output_video_shape = (1, 5, 3, 2, 18, 18)
        self.assertEqual(tuple(encoded_videos.shape), expected_output_video_shape)

        # Test batched
        encoded_videos = image_processing(images=None, videos=video_inputs, return_tensors="pt").pixel_values
        expected_output_video_shape = (5, 5, 3, 2, 18, 18)
        self.assertEqual(tuple(encoded_videos.shape), expected_output_video_shape)

    def test_call_numpy_4_channels(self):
        # Test that can process images which have an arbitrary number of channels
        # Initialize image_processing
        image_processor = self.image_processing_class(**self.image_processor_dict)

        # create random numpy tensors
        # NOTE(review): this mutates the shared tester and is never reset, so tests running
        # after this one in the same instance would see num_channels == 4.
        self.image_processor_tester.num_channels = 4
        image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, numpify=True)

        # Test not batched input
        encoded_images = image_processor(
            image_inputs[0],
            return_tensors="pt",
            input_data_format="channels_last",
            image_mean=0,
            image_std=1,
        ).pixel_values
        expected_output_image_shape = self.image_processor_tester.expected_output_image_shape([image_inputs[0]])
        self.assertEqual(tuple(encoded_images.shape), (1, *expected_output_image_shape))

        # Test batched
        encoded_images = image_processor(
            image_inputs,
            return_tensors="pt",
            input_data_format="channels_last",
            image_mean=0,
            image_std=1,
        ).pixel_values
        expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs)
        self.assertEqual(
            tuple(encoded_images.shape), (self.image_processor_tester.batch_size, *expected_output_image_shape)
        )
+"""Testing suite for the PyTorch ImageBind model.""" + +import inspect +import os +import tempfile +import unittest + +import numpy as np +from datasets import load_dataset +from torchvision import transforms + +from transformers import ( + CLIPTokenizer, + ImageBindAudioConfig, + ImageBindConfig, + ImageBindFeatureExtractor, + ImageBindImageProcessor, + ImageBindModel, + ImageBindProcessor, + ImageBindTextConfig, + ImageBindVisionConfig, +) +from transformers.image_utils import ( + OPENAI_CLIP_MEAN, + OPENAI_CLIP_STD, +) +from transformers.testing_utils import ( + require_torch, + require_torchaudio, + require_vision, + slow, + torch_device, +) +from transformers.utils import is_speech_available, is_torch_available, is_vision_available + +from ...test_configuration_common import ConfigTester +from ...test_modeling_common import ( + ModelTesterMixin, + _config_zero_init, + floats_tensor, + ids_tensor, + random_attention_mask, +) +from ...test_pipeline_mixin import PipelineTesterMixin + + +if is_torch_available(): + import torch + from torch import nn + + from transformers import ( + ImageBindAudioModel, + ImageBindAudioModelWithProjection, + ImageBindModel, + ImageBindTextModel, + ImageBindTextModelWithProjection, + ImageBindVisionModel, + ImageBindVisionModelWithProjection, + ) + + +if is_vision_available(): + pass + +if is_speech_available(): + import torchaudio + + +class ImageBindTextModelTester: + def __init__( + self, + parent, + batch_size=12, + seq_length=7, + is_training=False, + use_input_mask=True, + use_labels=True, + vocab_size=99, + hidden_size=32, + projection_dim=32, + num_hidden_layers=5, + num_attention_heads=4, + intermediate_size=37, + dropout=0.0, + attention_dropout=0.0, + max_position_embeddings=512, + layer_norm_eps=1e-6, + initializer_range=0.02, + logit_scale_init_value=14.2857, + learnable_logit_scale=True, + scope=None, + ): + self.parent = parent + self.batch_size = batch_size + self.seq_length = seq_length + self.is_training = 
is_training + self.use_input_mask = use_input_mask + self.use_labels = use_labels + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.projection_dim = projection_dim + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.dropout = dropout + self.attention_dropout = attention_dropout + self.layer_norm_eps = layer_norm_eps + self.max_position_embeddings = max_position_embeddings + self.initializer_range = initializer_range + self.logit_scale_init_value = logit_scale_init_value + self.learnable_logit_scale = learnable_logit_scale + self.scope = scope + + def prepare_config_and_inputs(self): + input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) + + input_mask = None + if self.use_input_mask: + input_mask = random_attention_mask([self.batch_size, self.seq_length]) + + if input_mask is not None: + batch_size, seq_length = input_mask.shape + rnd_start_indices = np.random.randint(1, seq_length - 1, size=(batch_size,)) + for batch_idx, start_index in enumerate(rnd_start_indices): + input_mask[batch_idx, :start_index] = 1 + input_mask[batch_idx, start_index:] = 0 + + config = self.get_config() + + return config, input_ids, input_mask + + def get_config(self): + return ImageBindTextConfig( + vocab_size=self.vocab_size, + hidden_size=self.hidden_size, + projection_dim=self.projection_dim, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads, + intermediate_size=self.intermediate_size, + dropout=self.dropout, + attention_dropout=self.attention_dropout, + layer_norm_eps=self.layer_norm_eps, + max_position_embeddings=self.max_position_embeddings, + initializer_range=self.initializer_range, + logit_scale_init_value=self.logit_scale_init_value, + learnable_logit_scale=self.learnable_logit_scale, + ) + + def create_and_check_model(self, config, input_ids, input_mask): + model = ImageBindTextModel(config=config) + 
model.to(torch_device) + model.eval() + with torch.no_grad(): + result = model(input_ids, attention_mask=input_mask) + result = model(input_ids) + self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)) + self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size)) + + def create_and_check_model_with_projection(self, config, input_ids, input_mask): + model = ImageBindTextModelWithProjection(config=config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + result = model(input_ids, attention_mask=input_mask) + result = model(input_ids) + self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)) + self.parent.assertEqual(result.text_embeds.shape, (self.batch_size, self.projection_dim)) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + config, input_ids, input_mask = config_and_inputs + inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask} + return config, inputs_dict + + +@require_torch +class ImageBindTextModelTest(ModelTesterMixin, unittest.TestCase): + all_model_classes = (ImageBindTextModel, ImageBindTextModelWithProjection) if is_torch_available() else () + fx_compatible = False + test_pruning = False + test_head_masking = False + + def setUp(self): + self.model_tester = ImageBindTextModelTester(self) + self.config_tester = ConfigTester(self, config_class=ImageBindTextConfig, hidden_size=37) + + def test_config(self): + self.config_tester.run_common_tests() + + def test_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(*config_and_inputs) + + def test_model_with_projection(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model_with_projection(*config_and_inputs) + + def test_training(self): + pass + + def 
class ImageBindVisionModelTester:
    """Builds small `ImageBindVisionConfig`s and random pixel inputs for the tests below."""

    def __init__(
        self,
        parent,
        batch_size=12,
        image_size=32,
        patch_size=8,
        num_channels=3,
        hidden_size=32,
        mlp_ratio=1.0,
        projection_dim=32,
        num_hidden_layers=5,
        num_attention_heads=4,
        is_training=False,
        logit_scale_init_value=None,
        learnable_logit_scale=False,
        scope=None,
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.hidden_size = hidden_size
        self.mlp_ratio = mlp_ratio
        self.projection_dim = projection_dim
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.is_training = is_training
        self.logit_scale_init_value = logit_scale_init_value
        self.learnable_logit_scale = learnable_logit_scale
        self.scope = scope

        # Although the vision tower uses a 3D conv, its time dimension is always 1,
        # so only the spatial patches count. As in ViT, the sequence length is the
        # number of patches plus one [CLS] token.
        patches_per_side = image_size // patch_size
        self.seq_length = patches_per_side**2 + 1

    def prepare_config_and_inputs(self):
        pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
        return self.get_config(), pixel_values

    def get_config(self):
        return ImageBindVisionConfig(
            image_size=self.image_size,
            patch_size=self.patch_size,
            num_channels=self.num_channels,
            hidden_size=self.hidden_size,
            mlp_ratio=self.mlp_ratio,
            projection_dim=self.projection_dim,
            num_hidden_layers=self.num_hidden_layers,
            num_attention_heads=self.num_attention_heads,
            logit_scale_init_value=self.logit_scale_init_value,
            learnable_logit_scale=self.learnable_logit_scale,
        )

    # TODO: fix image size and patch_size
    def create_and_check_model(self, config, pixel_values):
        model = ImageBindVisionModel(config=config)
        model.to(torch_device)
        model.eval()
        with torch.no_grad():
            result = model(pixel_values)
        # Expected sequence length: number of patches + 1 for the [CLS] token.
        patches = (self.image_size // self.patch_size) * (self.image_size // self.patch_size)
        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, patches + 1, self.hidden_size))
        self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size))

    # TODO: fix image size and patch_size
    def create_and_check_model_with_projection(self, config, pixel_values):
        model = ImageBindVisionModelWithProjection(config=config)
        model.to(torch_device)
        model.eval()
        with torch.no_grad():
            result = model(pixel_values)
        # Expected sequence length: number of patches + 1 for the [CLS] token.
        patches = (self.image_size // self.patch_size) * (self.image_size // self.patch_size)
        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, patches + 1, self.hidden_size))
        self.parent.assertEqual(result.image_embeds.shape, (self.batch_size, self.projection_dim))

    def prepare_config_and_inputs_for_common(self):
        config, pixel_values = self.prepare_config_and_inputs()
        return config, {"pixel_values": pixel_values}


@require_torch
class ImageBindVisionModelTest(ModelTesterMixin, unittest.TestCase):
    """
    Here we also overwrite some of the tests of test_modeling_common.py, as IMAGEBIND does not use input_ids, inputs_embeds,
    attention_mask and seq_length.
    """

    all_model_classes = (ImageBindVisionModel, ImageBindVisionModelWithProjection) if is_torch_available() else ()
    fx_compatible = False
    test_pruning = False
    test_resize_embeddings = False
    test_head_masking = False

    def setUp(self):
        self.model_tester = ImageBindVisionModelTester(self)
        self.config_tester = ConfigTester(
            self, config_class=ImageBindVisionConfig, has_text_modality=False, hidden_size=37
        )

    def test_config(self):
        self.config_tester.run_common_tests()

    @unittest.skip(reason="IMAGEBIND does not use inputs_embeds")
    def test_inputs_embeds(self):
        pass

    # NOTE(review): identical to `test_model_get_set_embeddings` below; one of the two
    # is likely a leftover rename and could be dropped.
    def test_model_common_attributes(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            self.assertIsInstance(model.get_input_embeddings(), (nn.Module))
            x = model.get_output_embeddings()
            self.assertTrue(x is None or isinstance(x, nn.Linear))

    def test_forward_signature(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            signature = inspect.signature(model.forward)
            # signature.parameters is an OrderedDict => so arg_names order is deterministic
            arg_names = [*signature.parameters.keys()]

            expected_arg_names = ["pixel_values"]
            self.assertListEqual(arg_names[:1], expected_arg_names)

    def test_model(self):
        self.model_tester.create_and_check_model(*self.model_tester.prepare_config_and_inputs())

    def test_model_with_projection(self):
        self.model_tester.create_and_check_model_with_projection(*self.model_tester.prepare_config_and_inputs())

    def test_training(self):
        pass

    def test_training_gradient_checkpointing(self):
        pass

    @unittest.skip(reason="ImageBindVisionModel has no base class and is not available in MODEL_MAPPING")
    def test_save_load_fast_init_from_base(self):
        pass

    @unittest.skip(reason="ImageBindVisionModel has no base class and is not available in MODEL_MAPPING")
    def test_save_load_fast_init_to_base(self):
        pass

    def test_model_get_set_embeddings(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            self.assertIsInstance(model.get_input_embeddings(), (nn.Module))
            x = model.get_output_embeddings()
            self.assertTrue(x is None or isinstance(x, nn.Linear))

    @slow
    def test_model_from_pretrained(self):
        model_name = "EduardoPacheco/imagebind-huge"
        model = ImageBindVisionModel.from_pretrained(model_name)
        self.assertIsNotNone(model)

    @slow
    def test_model_with_projection_from_pretrained(self):
        model_name = "EduardoPacheco/imagebind-huge"
        model = ImageBindVisionModelWithProjection.from_pretrained(model_name)
        self.assertIsNotNone(model)
        self.assertTrue(hasattr(model, "vision_projection"))
class ImageBindAudioModelTester:
    """Builds small `ImageBindAudioConfig`s and random mel-spectrogram inputs."""

    def __init__(
        self,
        parent,
        batch_size=12,
        patch_size=8,
        stride=8,
        num_channels=1,
        is_training=False,
        num_mel_bins=32,
        target_len=48,
        hidden_size=32,
        projection_dim=32,
        num_hidden_layers=2,
        num_attention_heads=2,
        mlp_ratio=1.0,
        add_kv_bias=True,
        logit_scale_init_value=20.0,
        learnable_logit_scale=False,
        scope=None,
    ):
        self.parent = parent
        # Input audio can be batched with multiple clips per sample.
        self.num_clips = 3
        # With clip batching, the encoder internally sees batch_size * num_clips items.
        self.actual_batch_size = batch_size
        self.batch_size = batch_size * self.num_clips  # this will be used internally
        self.patch_size = patch_size
        self.stride = stride
        self.num_channels = num_channels
        self.is_training = is_training
        self.num_mel_bins = num_mel_bins
        self.target_len = target_len
        self.hidden_size = hidden_size
        self.projection_dim = projection_dim
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.mlp_ratio = mlp_ratio
        self.add_kv_bias = add_kv_bias
        self.logit_scale_init_value = logit_scale_init_value
        self.learnable_logit_scale = learnable_logit_scale
        self.scope = scope

        # In the audio model the mel-spectrogram "image" size is num_mel_bins x target_len,
        # patchified with the configured patch size and stride.
        patches_along_height_dim = int((num_mel_bins - patch_size) / stride + 1)
        patches_along_width_dim = int((target_len - patch_size) / stride + 1)
        num_patches = patches_along_height_dim * patches_along_width_dim

        self.encoder_seq_length = num_patches + 1
        # add_kv_bias contributes one extra key/value position.
        self.key_length = num_patches + 1 if not add_kv_bias else num_patches + 2

    def prepare_config_and_inputs(self):
        input_features = floats_tensor(
            [self.actual_batch_size, self.num_clips, self.num_channels, self.num_mel_bins, self.target_len]
        )
        return self.get_config(), input_features

    def get_config(self):
        return ImageBindAudioConfig(
            patch_size=self.patch_size,
            stride=self.stride,
            num_channels=self.num_channels,
            num_mel_bins=self.num_mel_bins,
            target_len=self.target_len,
            hidden_size=self.hidden_size,
            projection_dim=self.projection_dim,
            num_hidden_layers=self.num_hidden_layers,
            num_attention_heads=self.num_attention_heads,
            mlp_ratio=self.mlp_ratio,
            add_kv_bias=self.add_kv_bias,
            logit_scale_init_value=self.logit_scale_init_value,
            learnable_logit_scale=self.learnable_logit_scale,
        )

    def create_and_check_model(self, config, input_features):
        model = ImageBindAudioModel(config=config)
        model.to(torch_device)
        model.eval()
        with torch.no_grad():
            result = model(input_features)
        # Hidden states are flattened over clips (batch_size here is already * num_clips).
        self.parent.assertEqual(
            result.last_hidden_state.shape, (self.batch_size, self.encoder_seq_length, self.hidden_size)
        )
        self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size))

    def create_and_check_model_with_projection(self, config, input_features):
        model = ImageBindAudioModelWithProjection(config=config)
        model.to(torch_device)
        model.eval()
        with torch.no_grad():
            result = model(input_features)
        self.parent.assertEqual(
            result.last_hidden_state.shape, (self.batch_size, self.encoder_seq_length, self.hidden_size)
        )
        # Projected embeddings are reduced back to one vector per original sample.
        self.parent.assertEqual(result.audio_embeds.shape, (self.actual_batch_size, self.projection_dim))

    def prepare_config_and_inputs_for_common(self):
        config, input_features = self.prepare_config_and_inputs()
        return config, {"input_features": input_features}


@require_torch
class ImageBindAudioModelTest(ModelTesterMixin, unittest.TestCase):
    """
    Here we also overwrite some of the tests of test_modeling_common.py, as IMAGEBIND does not use input_ids, inputs_embeds,
    attention_mask and seq_length.
    """

    all_model_classes = (ImageBindAudioModel, ImageBindAudioModelWithProjection) if is_torch_available() else ()
    fx_compatible = False
    test_pruning = False
    test_resize_embeddings = False
    test_head_masking = False

    def setUp(self):
        self.model_tester = ImageBindAudioModelTester(self)
        self.config_tester = ConfigTester(
            self, config_class=ImageBindAudioConfig, has_text_modality=False, hidden_size=37
        )

    def test_config(self):
        self.config_tester.run_common_tests()

    @unittest.skip(reason="ImageBind does not use inputs_embeds")
    def test_inputs_embeds(self):
        pass

    # NOTE(review): identical to `test_model_get_set_embeddings` below; one of the two
    # is likely a leftover rename and could be dropped.
    def test_model_common_attributes(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            self.assertIsInstance(model.get_input_embeddings(), (nn.Module))
            x = model.get_output_embeddings()
            self.assertTrue(x is None or isinstance(x, nn.Linear))

    def test_forward_signature(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            signature = inspect.signature(model.forward)
            # signature.parameters is an OrderedDict => so arg_names order is deterministic
            arg_names = [*signature.parameters.keys()]

            expected_arg_names = ["input_features"]
            self.assertListEqual(arg_names[:1], expected_arg_names)

    def test_model(self):
        self.model_tester.create_and_check_model(*self.model_tester.prepare_config_and_inputs())

    def test_model_with_projection(self):
        self.model_tester.create_and_check_model_with_projection(*self.model_tester.prepare_config_and_inputs())

    def test_training(self):
        pass

    def test_training_gradient_checkpointing(self):
        pass

    @unittest.skip(reason="ImageBindAudioModel has no base class and is not available in MODEL_MAPPING")
    def test_save_load_fast_init_from_base(self):
        pass

    @unittest.skip(reason="ImageBindAudioModel has no base class and is not available in MODEL_MAPPING")
    def test_save_load_fast_init_to_base(self):
        pass

    def test_model_get_set_embeddings(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            self.assertIsInstance(model.get_input_embeddings(), (nn.Module))
            x = model.get_output_embeddings()
            self.assertTrue(x is None or isinstance(x, nn.Linear))

    @slow
    def test_model_from_pretrained(self):
        model_name = "EduardoPacheco/imagebind-huge"
        model = ImageBindAudioModel.from_pretrained(model_name)
        self.assertIsNotNone(model)

    @slow
    def test_model_with_projection_from_pretrained(self):
        model_name = "EduardoPacheco/imagebind-huge"
        model = ImageBindAudioModelWithProjection.from_pretrained(model_name)
        self.assertIsNotNone(model)
        self.assertTrue(hasattr(model, "audio_projection"))
test_save_load_fast_init_from_base(self): + pass + + @unittest.skip(reason="ImageBindAudioModel has no base class and is not available in MODEL_MAPPING") + def test_save_load_fast_init_to_base(self): + pass + + def test_model_get_set_embeddings(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + self.assertIsInstance(model.get_input_embeddings(), (nn.Module)) + x = model.get_output_embeddings() + self.assertTrue(x is None or isinstance(x, nn.Linear)) + + @slow + def test_model_from_pretrained(self): + model_name = "EduardoPacheco/imagebind-huge" + model = ImageBindAudioModel.from_pretrained(model_name) + self.assertIsNotNone(model) + + @slow + def test_model_with_projection_from_pretrained(self): + model_name = "EduardoPacheco/imagebind-huge" + model = ImageBindAudioModelWithProjection.from_pretrained(model_name) + self.assertIsNotNone(model) + self.assertTrue(hasattr(model, "audio_projection")) + + +class ImageBindModelTester: + def __init__( + self, + parent, + text_kwargs=None, + vision_kwargs=None, + audio_kwargs=None, + projection_dim=32, + modality="text", + is_training=True, + ): + if text_kwargs is None: + text_kwargs = {} + if vision_kwargs is None: + vision_kwargs = {} + if audio_kwargs is None: + audio_kwargs = {} + + self.parent = parent + self.text_model_tester = ImageBindTextModelTester(parent, **text_kwargs) + self.vision_model_tester = ImageBindVisionModelTester(parent, **vision_kwargs) + self.audio_model_tester = ImageBindAudioModelTester(parent, **audio_kwargs) + self.projection_dim = projection_dim + self.batch_size = self.text_model_tester.batch_size # need bs for batching_equivalence test + # This is to make things easier and reuse ImageBindModelTester for all modalities + self.modality = modality + self.is_training = is_training + + def prepare_config_and_inputs(self): + _, input_ids, attention_mask = 
self.text_model_tester.prepare_config_and_inputs() + _, pixel_values = self.vision_model_tester.prepare_config_and_inputs() + _, input_features = self.audio_model_tester.prepare_config_and_inputs() + + config = self.get_config() + + return config, pixel_values, input_ids, attention_mask, input_features + + def get_config(self): + return ImageBindConfig( + self.text_model_tester.get_config().to_dict(), + self.vision_model_tester.get_config().to_dict(), + self.audio_model_tester.get_config().to_dict(), + projection_dim=self.projection_dim, + ) + + def create_and_check_text_vision_pair(self, config, pixel_values, input_ids, attention_mask): + model = ImageBindModel(config).to(torch_device).eval() + with torch.no_grad(): + result = model(pixel_values=pixel_values, input_ids=input_ids, attention_mask=attention_mask) + self.parent.assertEqual( + result.logits_per_image.shape, (self.vision_model_tester.batch_size, self.text_model_tester.batch_size) + ) + self.parent.assertEqual( + result.logits_per_text.shape, (self.text_model_tester.batch_size, self.vision_model_tester.batch_size) + ) + + def create_and_check_audio_vision_pair(self, config, pixel_values, input_features): + model = ImageBindModel(config).to(torch_device).eval() + with torch.no_grad(): + result = model(pixel_values=pixel_values, input_features=input_features) + self.parent.assertEqual( + result.logits_per_image.shape, (self.vision_model_tester.batch_size, self.audio_model_tester.batch_size) + ) + self.parent.assertEqual( + result.logits_per_audio.shape, (self.audio_model_tester.batch_size, self.vision_model_tester.batch_size) + ) + + def create_and_check_model(self, config, pixel_values, input_ids=None, attention_mask=None, input_features=None): + if self.modality == "text": + self.create_and_check_text_vision_pair( + config, + pixel_values, + input_ids, + attention_mask, + ) + elif self.modality == "audio": + self.create_and_check_audio_vision_pair(config, pixel_values, input_features) + + def 
prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + config, pixel_values, input_ids, attention_mask, input_features = config_and_inputs + inputs_dict = { + "pixel_values": pixel_values, + "input_ids": input_ids, + "attention_mask": attention_mask, + "input_features": input_features, + "return_loss": True, + } + + if self.modality == "text": + inputs_dict.pop("input_features") + elif self.modality == "audio": + inputs_dict.pop("input_ids") + inputs_dict.pop("attention_mask") + + return config, inputs_dict + + +@require_torch +class ImageBindModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): + all_model_classes = (ImageBindModel,) if is_torch_available() else () + pipeline_model_mapping = {"feature-extraction": ImageBindModel} if is_torch_available() else {} + fx_compatible = False + test_torchscript = False + test_head_masking = False + test_pruning = False + test_resize_embeddings = False + test_attention_outputs = False + + def setUp(self): + self.model_tester = ImageBindModelTester(self) + + def test_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(*config_and_inputs) + + @unittest.skip(reason="Hidden_states is tested in individual model tests") + def test_hidden_states_output(self): + pass + + @unittest.skip(reason="Inputs_embeds is tested in individual model tests") + def test_inputs_embeds(self): + pass + + @unittest.skip(reason="Retain_grad is tested in individual model tests") + def test_retain_grad_hidden_states_attentions(self): + pass + + @unittest.skip(reason="ImageBindModel does not have input/output embeddings") + def test_model_common_attributes(self): + pass + + def _create_and_check_torchscript(self, config, inputs_dict): + if not self.test_torchscript: + return + + configs_no_init = _config_zero_init(config) # To be sure we have no Nan + configs_no_init.torchscript = True + configs_no_init.return_dict = False 
+ for model_class in self.all_model_classes: + model = model_class(config=configs_no_init) + model.to(torch_device) + model.eval() + + try: + traced_model = torch.jit.trace(model, example_kwarg_inputs=inputs_dict) + except RuntimeError: + self.fail("Couldn't trace module.") + + with tempfile.TemporaryDirectory() as tmp_dir_name: + pt_file_name = os.path.join(tmp_dir_name, "traced_model.pt") + + try: + torch.jit.save(traced_model, pt_file_name) + except Exception: + self.fail("Couldn't save module.") + + try: + loaded_model = torch.jit.load(pt_file_name) + except Exception: + self.fail("Couldn't load module.") + + model.to(torch_device) + model.eval() + + loaded_model.to(torch_device) + loaded_model.eval() + + model_state_dict = model.state_dict() + loaded_model_state_dict = loaded_model.state_dict() + + self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys())) + + models_equal = True + for layer_name, p1 in model_state_dict.items(): + p2 = loaded_model_state_dict[layer_name] + if p1.data.ne(p2.data).sum() > 0: + models_equal = False + + self.assertTrue(models_equal) + + def test_load_vision_text_config(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + # Save ImageBindConfig and check if we can load ImageBindVisionConfig from it + with tempfile.TemporaryDirectory() as tmp_dir_name: + config.save_pretrained(tmp_dir_name) + vision_config = ImageBindVisionConfig.from_pretrained(tmp_dir_name) + self.assertDictEqual(config.vision_config.to_dict(), vision_config.to_dict()) + + # Save ImageBindConfig and check if we can load ImageBindTextConfig from it + with tempfile.TemporaryDirectory() as tmp_dir_name: + config.save_pretrained(tmp_dir_name) + text_config = ImageBindTextConfig.from_pretrained(tmp_dir_name) + self.assertDictEqual(config.text_config.to_dict(), text_config.to_dict()) + + @unittest.skip(reason="ImageBindModel does not have input/output embeddings") + def test_model_get_set_embeddings(self): 
+ pass + + @slow + def test_model_from_pretrained(self): + model_name = "EduardoPacheco/imagebind-huge" + model = ImageBindModel.from_pretrained(model_name) + self.assertIsNotNone(model) + + +def prepare_inputs(): + ds = load_dataset("EduardoPacheco/imagebind-example-data", split="train") + images = ds["image"] + texts = ds["text"] + audios = [ + torchaudio.functional.resample( + torch.from_numpy(audio["array"]), orig_freq=audio["sampling_rate"], new_freq=16000 + ).numpy() + for audio in ds["audio"] + ] + + return images, texts, audios + + +@require_vision +@require_torchaudio +@require_torch +class ImageBindModelIntegrationTest(unittest.TestCase): + @slow + def test_inference(self): + model_name = "EduardoPacheco/imagebind-huge" + model = ImageBindModel.from_pretrained(model_name).to(torch_device) + processor = ImageBindProcessor.from_pretrained(model_name) + + original_image_processor = transforms.Compose( + [ + transforms.Resize(224, interpolation=transforms.InterpolationMode.BICUBIC), + transforms.CenterCrop(224), + transforms.ToTensor(), + transforms.Normalize( + mean=OPENAI_CLIP_MEAN, + std=OPENAI_CLIP_STD, + ), + ] + ) + + images, texts, audios = prepare_inputs() + inputs = processor(text=texts, images=images, audio=audios, padding=True, return_tensors="pt").to(torch_device) + + expected_input_features = torch.tensor( + [ + [-1.2776, -0.9167, -1.2776], + [-1.2439, -0.8372, -0.8748], + [-1.1235, -0.7492, -1.0867], + ], + device=torch_device, + ) + + expected_pixel_values = torch.tensor( + [[-0.1134, 0.7392, 1.3069], [-0.6244, 0.1089, 0.2688], [-0.8434, 0.1089, 0.9088]], device=torch_device + ) + + expected_input_ids = torch.tensor( + [[49406, 320, 3329, 49407, 49407], [49406, 320, 1615, 49407, 49407], [49406, 320, 1929, 269, 49407]], + device=torch_device, + ) + + expected_attention_mask = torch.tensor( + [[1, 1, 1, 1, 0], [1, 1, 1, 1, 0], [1, 1, 1, 1, 1]], device=torch_device + ) + + self.assertTrue(torch.allclose(inputs.input_features[:, :, 0, 0, 0], 
expected_input_features, atol=1e-4)) + self.assertTrue(torch.allclose(inputs.pixel_values[:, :, 0, 0], expected_pixel_values, atol=1e-4)) + self.assertTrue(torch.allclose(inputs.input_ids, expected_input_ids, atol=1e-4)) + self.assertTrue(torch.allclose(inputs.attention_mask, expected_attention_mask, atol=1e-4)) + + with torch.no_grad(): + outputs_vision_text = model( + pixel_values=inputs.pixel_values, input_ids=inputs.input_ids, attention_mask=inputs.attention_mask + ) + outputs_vision_audio = model(pixel_values=inputs.pixel_values, input_features=inputs.input_features) + + expected_image_embeds = torch.tensor( + [ + [0.0188, 0.0075, 0.0532, 0.0326, -0.0159], + [0.0190, 0.0106, 0.0275, 0.0189, -0.0268], + [-0.0104, -0.0203, 0.0048, -0.0158, 0.0076], + ], + device=torch_device, + ) + expected_text_embeds = torch.tensor( + [ + [-1.3476, -1.5732, -0.7386, 9.7949, 0.5856], + [-0.4342, -0.9050, -4.2879, 7.4123, -0.4906], + [-1.0745, -4.0049, -1.0697, 5.8861, -0.7583], + ], + device=torch_device, + ) + expected_audio_embeds = torch.tensor( + [ + [0.3244, -0.3748, 0.3956, 0.5600, -0.1932], + [0.7091, 0.2073, -1.0133, 0.4689, -0.2142], + [-0.0281, -0.4922, 1.0057, 0.0459, -0.2271], + ], + device=torch_device, + ) + + self.assertTrue(torch.allclose(outputs_vision_text.image_embeds[:, :5], expected_image_embeds, atol=1e-4)) + self.assertTrue(torch.allclose(outputs_vision_text.text_embeds[:, :5], expected_text_embeds, atol=1e-4)) + self.assertTrue(torch.allclose(outputs_vision_audio.audio_embeds[:, :5], expected_audio_embeds, atol=1e-4)) + self.assertTrue(torch.allclose(outputs_vision_text.image_embeds, outputs_vision_audio.image_embeds, atol=1e-4)) + + expected_logits_per_audio = torch.tensor( + [[7.3541, 1.1908, 2.2897], [1.1930, 3.0097, 2.0238], [0.9584, 1.2224, 4.2325]], device=torch_device + ) + + expected_logits_per_image_with_text = torch.tensor( + [[23.6142, 19.1165, 13.2448], [12.1343, 23.4165, 11.8823], [15.8471, 20.1186, 24.8246]], + device=torch_device, + ) + + 
self.assertTrue(torch.allclose(outputs_vision_audio.logits_per_audio, expected_logits_per_audio, atol=1e-4)) + self.assertTrue( + torch.allclose(outputs_vision_text.logits_per_image, expected_logits_per_image_with_text, atol=1e-4) + ) + + del model + + torch.manual_seed(0) + config = ImageBindConfig() + model = ImageBindModel(config).to(torch_device) + model.eval() + + tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") + image_processor = ImageBindImageProcessor() + feature_extractor = ImageBindFeatureExtractor() + processor = ImageBindProcessor(image_processor, tokenizer, feature_extractor) + + inputs_audio_vision = processor(images=images, audio=audios, return_tensors="pt").to(torch_device) + inputs_text_vision = processor(images=images, text=texts, return_tensors="pt", padding=True).to(torch_device) + + expected_input_features = torch.tensor( + [ + [-1.2776, -0.9167, -1.2776], + [-1.2439, -0.8372, -0.8748], + [-1.1235, -0.7492, -1.0867], + ], + device=torch_device, + ) + + expected_pixel_values = torch.stack([original_image_processor(image) for image in images]).to(torch_device) + + assert torch.allclose(inputs_audio_vision["pixel_values"], expected_pixel_values, atol=1e-4) + assert torch.allclose(inputs_audio_vision["input_features"][:, :, 0, 0, 0], expected_input_features, atol=1e-4) + + expected_output_vision = torch.tensor( + [ + [0.0217, -0.0969, -0.0044, -0.0203, 0.0178], + [0.0347, -0.0987, -0.0190, -0.0034, 0.0352], + [0.0389, -0.0910, -0.0230, -0.0072, 0.0455], + ], + device=torch_device, + ) + expected_output_text = torch.tensor( + [ + [-0.1995, 0.2042, 0.7407, 0.5275, -0.4482], + [-0.1800, 0.2736, 0.5057, 0.4819, -0.5618], + [-0.2461, 0.2926, 0.4936, 0.4322, -0.2178], + ], + device=torch_device, + ) + expected_output_audio = torch.tensor( + [ + [-0.0882, -0.4557, 0.3396, 1.1183, -0.0692], + [-0.4186, -0.2179, 0.0913, 0.9061, -0.0390], + [-0.1190, -0.5368, 0.2956, 1.1277, 0.0037], + ], + device=torch_device, + ) + + 
outputs_text_vision = model(**inputs_text_vision) + outputs_audio_vision = model(**inputs_audio_vision) + + assert torch.allclose(outputs_text_vision.image_embeds[:, :5], expected_output_vision, atol=1e-3) + assert torch.allclose(outputs_text_vision.text_embeds[:, :5], expected_output_text, atol=1e-4) + assert torch.allclose(outputs_audio_vision.audio_embeds[:, :5], expected_output_audio, atol=1e-4) + assert torch.allclose(outputs_text_vision.image_embeds, outputs_audio_vision.image_embeds, atol=1e-4) diff --git a/tests/models/imagebind/test_processor_imagebind.py b/tests/models/imagebind/test_processor_imagebind.py new file mode 100644 index 000000000000..37a540a66d8a --- /dev/null +++ b/tests/models/imagebind/test_processor_imagebind.py @@ -0,0 +1,394 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import random +import shutil +import tempfile +import unittest + +import numpy as np +import pytest + +from transformers import CLIPTokenizer, CLIPTokenizerFast, ImageBindFeatureExtractor +from transformers.testing_utils import require_torch, require_torchaudio, require_vision +from transformers.utils import is_vision_available + + +if is_vision_available(): + from transformers import ImageBindImageProcessor, ImageBindProcessor + +from ...test_processing_common import ProcessorTesterMixin + + +global_rng = random.Random() + + +# Copied from tests.models.whisper.test_feature_extraction_whisper.floats_list +def floats_list(shape, scale=1.0, rng=None, name=None): + """Creates a random float32 tensor""" + if rng is None: + rng = global_rng + + values = [] + for batch_idx in range(shape[0]): + values.append([]) + for _ in range(shape[1]): + values[-1].append(rng.random() * scale) + + return values + + +@require_vision +@require_torchaudio +class ImageBindProcessorTest(ProcessorTesterMixin, unittest.TestCase): + processor_class = ImageBindProcessor + + def setUp(self): + self.tmpdirname = tempfile.mkdtemp() + self.checkpoint = "EduardoPacheco/imagebind-huge" + + image_processor = ImageBindImageProcessor() + tokenizer_slow = CLIPTokenizer.from_pretrained(self.checkpoint) + tokenizer_fast = CLIPTokenizerFast.from_pretrained(self.checkpoint) + feature_extractor = ImageBindFeatureExtractor() + + processor_slow = ImageBindProcessor(image_processor, tokenizer_slow, feature_extractor) + processor_fast = ImageBindProcessor(image_processor, tokenizer_fast, feature_extractor) + + processor_slow.save_pretrained(self.tmpdirname) + processor_fast.save_pretrained(self.tmpdirname) + + def get_tokenizer(self, **kwargs): + return CLIPTokenizer.from_pretrained(self.checkpoint, **kwargs) + + def get_rust_tokenizer(self, **kwargs): + return CLIPTokenizerFast.from_pretrained(self.checkpoint, **kwargs) + + def get_image_processor(self, **kwargs): + return 
ImageBindImageProcessor.from_pretrained(self.checkpoint, **kwargs) + + def get_feature_extractor(self, **kwargs): + return ImageBindFeatureExtractor.from_pretrained(self.checkpoint, **kwargs) + + def tearDown(self): + shutil.rmtree(self.tmpdirname) + + def prepare_audio_inputs(self): + return [np.random.rand(1500)] + + def test_save_load_pretrained_default(self): + tokenizer_slow = self.get_tokenizer() + tokenizer_fast = self.get_rust_tokenizer() + image_processor = self.get_image_processor() + feature_extractor = self.get_feature_extractor() + + processor_slow = ImageBindProcessor( + tokenizer=tokenizer_slow, image_processor=image_processor, feature_extractor=feature_extractor + ) + processor_slow.save_pretrained(self.tmpdirname) + processor_slow = ImageBindProcessor.from_pretrained(self.tmpdirname, use_fast=False) + + processor_fast = ImageBindProcessor( + tokenizer=tokenizer_fast, image_processor=image_processor, feature_extractor=feature_extractor + ) + processor_fast.save_pretrained(self.tmpdirname) + processor_fast = ImageBindProcessor.from_pretrained(self.tmpdirname) + + self.assertEqual(processor_slow.tokenizer.get_vocab(), tokenizer_slow.get_vocab()) + self.assertEqual(processor_fast.tokenizer.get_vocab(), tokenizer_fast.get_vocab()) + self.assertEqual(tokenizer_slow.get_vocab(), tokenizer_fast.get_vocab()) + self.assertIsInstance(processor_slow.tokenizer, CLIPTokenizer) + self.assertIsInstance(processor_fast.tokenizer, CLIPTokenizerFast) + + self.assertEqual(processor_slow.image_processor.to_json_string(), image_processor.to_json_string()) + self.assertEqual(processor_fast.image_processor.to_json_string(), image_processor.to_json_string()) + self.assertIsInstance(processor_slow.image_processor, ImageBindImageProcessor) + self.assertIsInstance(processor_fast.image_processor, ImageBindImageProcessor) + + self.assertEqual(processor_slow.feature_extractor.to_json_string(), feature_extractor.to_json_string()) + 
self.assertEqual(processor_fast.feature_extractor.to_json_string(), feature_extractor.to_json_string()) + self.assertIsInstance(processor_slow.feature_extractor, ImageBindFeatureExtractor) + self.assertIsInstance(processor_fast.feature_extractor, ImageBindFeatureExtractor) + + def test_save_load_pretrained_additional_features(self): + processor = ImageBindProcessor( + tokenizer=self.get_tokenizer(), + image_processor=self.get_image_processor(), + feature_extractor=self.get_feature_extractor(), + ) + processor.save_pretrained(self.tmpdirname) + + tokenizer_add_kwargs = self.get_tokenizer(bos_token="(BOS)", eos_token="(EOS)") + # Need to put same kwargs for both image_processor and feature_extractor as they share the same config :/ + image_processor_add_kwargs = self.get_image_processor(do_convert_rgb=False, do_chunk=False, num_chunks=5) + feature_extractor_add_kwargs = self.get_feature_extractor(do_convert_rgb=False, do_chunk=False, num_chunks=5) + + processor = ImageBindProcessor.from_pretrained( + self.tmpdirname, bos_token="(BOS)", eos_token="(EOS)", do_convert_rgb=False, do_chunk=False, num_chunks=5 + ) + + self.assertEqual(processor.tokenizer.get_vocab(), tokenizer_add_kwargs.get_vocab()) + self.assertIsInstance(processor.tokenizer, CLIPTokenizerFast) + + self.assertEqual(processor.image_processor.to_json_string(), image_processor_add_kwargs.to_json_string()) + self.assertIsInstance(processor.image_processor, ImageBindImageProcessor) + + self.assertEqual(processor.feature_extractor.to_json_string(), feature_extractor_add_kwargs.to_json_string()) + self.assertIsInstance(processor.feature_extractor, ImageBindFeatureExtractor) + + def test_image_processor(self): + image_processor = self.get_image_processor() + tokenizer = self.get_tokenizer() + + processor = ImageBindProcessor( + tokenizer=tokenizer, image_processor=image_processor, feature_extractor=self.get_feature_extractor() + ) + + image_input = self.prepare_image_inputs() + + input_image_proc = 
image_processor(image_input, return_tensors="np") + input_processor = processor(images=image_input, return_tensors="np") + + for key in input_image_proc.keys(): + self.assertAlmostEqual(input_image_proc[key].sum(), input_processor[key].sum(), delta=1e-2) + + def test_feature_extractor(self): + feature_extractor = self.get_feature_extractor() + tokenizer = self.get_tokenizer() + + processor = ImageBindProcessor( + tokenizer=tokenizer, feature_extractor=feature_extractor, image_processor=self.get_image_processor() + ) + + raw_speech = self.prepare_audio_inputs() + + input_feat_extract = feature_extractor(raw_speech, return_tensors="np") + input_processor = processor(audio=raw_speech, return_tensors="np") + + for key in input_feat_extract.keys(): + self.assertAlmostEqual(input_feat_extract[key].sum(), input_processor[key].sum(), delta=1e-2) + + def test_tokenizer(self): + image_processor = self.get_image_processor() + tokenizer = self.get_tokenizer() + feature_extractor = self.get_feature_extractor() + + processor = ImageBindProcessor( + tokenizer=tokenizer, image_processor=image_processor, feature_extractor=feature_extractor + ) + + input_str = "lower newer" + + encoded_processor = processor(text=input_str) + + encoded_tok = tokenizer(input_str) + + for key in encoded_tok.keys(): + self.assertListEqual(encoded_tok[key], encoded_processor[key]) + + def test_processor(self): + image_processor = self.get_image_processor() + tokenizer = self.get_tokenizer() + feature_extractor = self.get_feature_extractor() + + processor = ImageBindProcessor( + tokenizer=tokenizer, image_processor=image_processor, feature_extractor=feature_extractor + ) + + input_str = "lower newer" + image_input = self.prepare_image_inputs() + + inputs = processor(text=input_str, images=image_input) + + self.assertListEqual(list(inputs.keys()), ["input_ids", "attention_mask", "pixel_values"]) + + # test if it raises when no input is passed + with pytest.raises(ValueError): + processor() + + def 
test_tokenizer_decode(self): + image_processor = self.get_image_processor() + tokenizer = self.get_tokenizer() + feature_extractor = self.get_feature_extractor() + + processor = ImageBindProcessor( + tokenizer=tokenizer, image_processor=image_processor, feature_extractor=feature_extractor + ) + + predicted_ids = [[1, 4, 5, 8, 1, 0, 8], [3, 4, 3, 1, 1, 8, 9]] + + decoded_processor = processor.batch_decode(predicted_ids) + decoded_tok = tokenizer.batch_decode(predicted_ids) + + self.assertListEqual(decoded_tok, decoded_processor) + + def test_model_input_names(self): + image_processor = self.get_image_processor() + tokenizer = self.get_tokenizer() + feature_extractor = self.get_feature_extractor() + + processor = ImageBindProcessor( + tokenizer=tokenizer, image_processor=image_processor, feature_extractor=feature_extractor + ) + + input_str = "lower newer" + image_input = self.prepare_image_inputs() + audio_input = self.prepare_audio_inputs() + + inputs = processor(text=input_str, images=image_input, audio=audio_input) + + self.assertListEqual(list(inputs.keys()), processor.model_input_names) + + @require_torch + def test_doubly_passed_kwargs_audio(self): + if "feature_extractor" not in self.processor_class.attributes: + self.skipTest(f"feature_extractor attribute not present in {self.processor_class}") + feature_extractor = self.get_component("feature_extractor") + image_processor = self.get_component("image_processor") + if hasattr(self, "get_tokenizer"): + tokenizer = self.get_tokenizer() + elif hasattr(self, "get_component"): + tokenizer = self.get_component("tokenizer") + if not tokenizer.pad_token: + tokenizer.pad_token = "[TEST_PAD]" + processor = self.processor_class( + tokenizer=tokenizer, image_processor=image_processor, feature_extractor=feature_extractor + ) + self.skip_processor_without_typed_kwargs(processor) + + input_str = ["lower newer"] + raw_speech = floats_list((3, 1000)) + with self.assertRaises(ValueError): + _ = processor( + text=input_str, + 
audio=raw_speech, + audio_kwargs={"padding": "max_length"}, + padding="max_length", + ) + + @require_torch + def test_unstructured_kwargs_audio(self): + if "feature_extractor" not in self.processor_class.attributes: + self.skipTest(f"feature_extractor attribute not present in {self.processor_class}") + feature_extractor = self.get_component("feature_extractor") + image_processor = self.get_component("image_processor") + if hasattr(self, "get_tokenizer"): + tokenizer = self.get_tokenizer(max_length=117) + elif hasattr(self, "get_component"): + tokenizer = self.get_component("tokenizer", max_length=117) + if not tokenizer.pad_token: + tokenizer.pad_token = "[TEST_PAD]" + processor = self.processor_class( + tokenizer=tokenizer, image_processor=image_processor, feature_extractor=feature_extractor + ) + self.skip_processor_without_typed_kwargs(processor) + + input_str = "lower newer" + raw_speech = floats_list((3, 1000)) + inputs = processor( + text=input_str, + audio=raw_speech, + return_tensors="pt", + padding="max_length", + max_length=76, + ) + + if "input_ids" in inputs: + self.assertEqual(len(inputs["input_ids"][0]), 76) + elif "labels" in inputs: + self.assertEqual(len(inputs["labels"][0]), 76) + + @require_torch + def test_tokenizer_defaults_preserved_by_kwargs_audio(self): + if "feature_extractor" not in self.processor_class.attributes: + self.skipTest(f"feature_extractor attribute not present in {self.processor_class}") + feature_extractor = self.get_component("feature_extractor") + image_processor = self.get_component("image_processor") + if hasattr(self, "get_tokenizer"): + tokenizer = self.get_tokenizer(max_length=117, padding="max_length") + elif hasattr(self, "get_component"): + tokenizer = self.get_component("tokenizer", max_length=117, padding="max_length") + else: + self.assertTrue(False, "Processor doesn't have get_tokenizer or get_component defined") + if not tokenizer.pad_token: + tokenizer.pad_token = "[TEST_PAD]" + processor = 
self.processor_class( + tokenizer=tokenizer, image_processor=image_processor, feature_extractor=feature_extractor + ) + self.skip_processor_without_typed_kwargs(processor) + input_str = "lower newer" + raw_speech = floats_list((3, 1000)) + inputs = processor(text=input_str, audio=raw_speech, return_tensors="pt") + if "input_ids" in inputs: + self.assertEqual(len(inputs["input_ids"][0]), 117) + elif "labels" in inputs: + self.assertEqual(len(inputs["labels"][0]), 117) + + @require_torch + @require_vision + def test_structured_kwargs_audio_nested(self): + if "feature_extractor" not in self.processor_class.attributes: + self.skipTest(f"feature_extractor attribute not present in {self.processor_class}") + feature_extractor = self.get_component("feature_extractor") + image_processor = self.get_component("image_processor") + if hasattr(self, "get_tokenizer"): + tokenizer = self.get_tokenizer() + elif hasattr(self, "get_component"): + tokenizer = self.get_component("tokenizer") + if not tokenizer.pad_token: + tokenizer.pad_token = "[TEST_PAD]" + processor = self.processor_class( + tokenizer=tokenizer, image_processor=image_processor, feature_extractor=feature_extractor + ) + self.skip_processor_without_typed_kwargs(processor) + + input_str = ["lower newer"] + raw_speech = floats_list((3, 1000)) + + # Define the kwargs for each modality + all_kwargs = { + "common_kwargs": {"return_tensors": "pt"}, + "text_kwargs": {"padding": "max_length", "max_length": 76}, + "audio_kwargs": {"padding": "max_length", "max_length": 66}, + } + + inputs = processor(text=input_str, audio=raw_speech, **all_kwargs) + if "input_ids" in inputs: + self.assertEqual(len(inputs["input_ids"][0]), 76) + elif "labels" in inputs: + self.assertEqual(len(inputs["labels"][0]), 76) + + @require_torch + def test_kwargs_overrides_default_tokenizer_kwargs_audio(self): + if "feature_extractor" not in self.processor_class.attributes: + self.skipTest(f"feature_extractor attribute not present in 
{self.processor_class}") + feature_extractor = self.get_component("feature_extractor") + image_processor = self.get_component("image_processor") + if hasattr(self, "get_tokenizer"): + tokenizer = self.get_tokenizer(max_length=117) + elif hasattr(self, "get_component"): + tokenizer = self.get_component("tokenizer", max_length=117) + if not tokenizer.pad_token: + tokenizer.pad_token = "[TEST_PAD]" + processor = self.processor_class( + tokenizer=tokenizer, image_processor=image_processor, feature_extractor=feature_extractor + ) + self.skip_processor_without_typed_kwargs(processor) + input_str = "lower newer" + raw_speech = floats_list((3, 1000)) + inputs = processor(text=input_str, audio=raw_speech, return_tensors="pt", max_length=112, padding="max_length") + if "input_ids" in inputs: + self.assertEqual(len(inputs["input_ids"][0]), 112) + elif "labels" in inputs: + self.assertEqual(len(inputs["labels"][0]), 112) diff --git a/utils/check_repo.py b/utils/check_repo.py index 6872dada3d93..3ce5eaf85f6c 100644 --- a/utils/check_repo.py +++ b/utils/check_repo.py @@ -252,6 +252,12 @@ "FlavaMultimodalModel", "GPT2DoubleHeadsModel", "GPTSw3DoubleHeadsModel", + "ImageBindTextModel", + "ImageBindTextModelWithProjection", + "ImageBindVisionModel", + "ImageBindVisionModelWithProjection", + "ImageBindAudioModel", + "ImageBindAudioModelWithProjection", "InstructBlipVisionModel", "InstructBlipQFormerModel", "InstructBlipVideoVisionModel",