diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 07aad5be5b57..787aa2592aa9 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -632,6 +632,8 @@ title: Jamba - local: model_doc/jetmoe title: JetMoe + - local: model_doc/jina_embeddings_v3 + title: jina_embeddings_v3 - local: model_doc/led title: LED - local: model_doc/lfm2 diff --git a/docs/source/en/model_doc/jina_embeddings_v3.md b/docs/source/en/model_doc/jina_embeddings_v3.md new file mode 100644 index 000000000000..e30af7e50ea9 --- /dev/null +++ b/docs/source/en/model_doc/jina_embeddings_v3.md @@ -0,0 +1,165 @@ + + +*This model was released on 2024-09-16 and added to Hugging Face Transformers on 2026-03-18.* + +
+
+ PyTorch + FlashAttention + SDPA +
+
+ + +# JinaEmbeddingsV3 + +The [Jina-Embeddings-v3](https://huggingface.co/papers/2409.10173) is a multilingual, multi-task text embedding model designed for a variety of NLP applications. Based on the XLM-RoBERTa architecture, this model uses **Rotary Position Embeddings (RoPE)** in place of absolute position embeddings to support long input sequences up to 8192 tokens. Additionally, it features 5 built-in **Task-Specific LoRA Adapters** that allow the model to generate task-specific embeddings (e.g., for retrieval vs. classification) without increasing inference latency significantly. + + +You can find the original Jina Embeddings v3 checkpoints under the [Jina AI](https://huggingface.co/jinaai) organization. + + +> [!TIP] +> Click on the Jina Embeddings v3 models in the right sidebar for more examples of how to apply the model to different language tasks. + +The example below demonstrates how to extract features (embeddings) with [`Pipeline`] and [`AutoModel`]. + + + + +```py +import torch +from transformers import pipeline + +pipeline = pipeline( + task="feature-extraction", + model="jinaai/jina-embeddings-v3-hf", +) +# Returns a list of lists containing the embeddings for each token +embeddings = pipeline("Jina Embeddings V3 is great for semantic search.") +``` + + + + + + +```py +import torch +from transformers import AutoModel, AutoTokenizer + +tokenizer = AutoTokenizer.from_pretrained("jinaai/jina-embeddings-v3-hf") +model = AutoModel.from_pretrained("jinaai/jina-embeddings-v3-hf", device_map="auto") + +prompt = "Jina Embeddings V3 is great for semantic search." 
+inputs = tokenizer(prompt, return_tensors="pt").to(model.device) + +with torch.no_grad(): + outputs = model(**inputs) + # The base AutoModel returns the raw hidden states for all tokens + last_hidden_states = outputs.last_hidden_state + +print(f"Features shape: {last_hidden_states.shape}") +``` + + + + +## Task-Specific LoRA Adapters + +A key feature of `JinaEmbeddingsV3` is its LoRA adapters, which allow you to tailor the output embeddings to specific use cases without the overhead of loading entirely different models. + +The following tasks are supported: + +* **`retrieval.query`**: Used for query embeddings in asymmetric retrieval tasks (e.g., search queries). +* **`retrieval.passage`**: Used for passage embeddings in asymmetric retrieval tasks (e.g., the documents being searched). +* **`separation`**: Used for embeddings in clustering and re-ranking applications. +* **`classification`**: Used for embeddings in classification tasks. +* **`text-matching`**: Used for embeddings in tasks that quantify similarity between two texts, such as Semantic Textual Similarity (STS) or symmetric retrieval tasks. + + +To generate high-quality sentence or paragraph embeddings, you need to apply **mean pooling** to the model's token embeddings. Mean pooling takes all token embeddings from the model's output and averages them, masking out the padding tokens. + +Here is how you can generate sentence embeddings tailored for a retrieval query task using the `AutoModel` API. 
+ +```python +import torch +import torch.nn.functional as F +from transformers import AutoTokenizer, AutoModel + +def mean_pooling(model_output, attention_mask): + # First element of model_output contains all token embeddings + token_embeddings = model_output[0] + input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() + + # Sum the embeddings and divide by the number of non-padding tokens + sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1) + sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9) + return sum_embeddings / sum_mask + + +sentences = [ + "How is the weather today?", + "What is the current weather like today?" +] + +tokenizer = AutoTokenizer.from_pretrained("jinaai/jina-embeddings-v3-hf") +model = AutoModel.from_pretrained("jinaai/jina-embeddings-v3-hf") + +encoded_input = tokenizer(sentences, padding=True, truncation=True, return_tensors="pt").to(model.device) + +# Set up the adapter mask for your specific task +task = 'retrieval_query' # Can be any of (retrieval_passage, separation, classification, text_matching) depending on the use-case. 
+ +model.load_adapter("jinaai/jina-embeddings-v3-hf", adapter_name=task, adapter_kwargs={"subfolder": task}) + +model.set_adapter(task) + +with torch.no_grad(): + model_output = model(**encoded_input) + +embeddings = mean_pooling(model_output, encoded_input["attention_mask"]) +embeddings = F.normalize(embeddings, p=2, dim=1) + +print(embeddings.shape) +# Output: torch.Size([2, 1024]) +``` + + +## JinaEmbeddingsV3Config + +[[autodoc]] JinaEmbeddingsV3Config + +## JinaEmbeddingsV3Model + +[[autodoc]] JinaEmbeddingsV3Model + - forward + +## JinaEmbeddingsV3ForMaskedLM + +[[autodoc]] JinaEmbeddingsV3ForMaskedLM + - forward + +## JinaEmbeddingsV3ForSequenceClassification + +[[autodoc]] JinaEmbeddingsV3ForSequenceClassification + - forward + +## JinaEmbeddingsV3ForTokenClassification + +[[autodoc]] JinaEmbeddingsV3ForTokenClassification + - forward + +## JinaEmbeddingsV3ForQuestionAnswering + +[[autodoc]] JinaEmbeddingsV3ForQuestionAnswering + - forward diff --git a/src/transformers/conversion_mapping.py b/src/transformers/conversion_mapping.py index 22d9ddcbcd03..c7452fc89d20 100755 --- a/src/transformers/conversion_mapping.py +++ b/src/transformers/conversion_mapping.py @@ -420,6 +420,22 @@ def _build_checkpoint_conversion_mapping(): target_patterns="LayerNorm.bias", ), ], + "jina_embeddings_v3": [ + WeightRenaming(source_patterns="emb_ln", target_patterns="embeddings.LayerNorm"), + WeightRenaming(source_patterns="encoder.layers", target_patterns="layers"), + WeightConverter( + source_patterns="mixer.Wqkv", + target_patterns=[ + "self_attn.q_proj", + "self_attn.k_proj", + "self_attn.v_proj", + ], + operations=[Chunk(dim=0)], + ), + WeightRenaming(source_patterns="mixer.out_proj", target_patterns="self_attn.o_proj"), + WeightRenaming(source_patterns="norm1", target_patterns="post_attention_layernorm"), + WeightRenaming(source_patterns="norm2", target_patterns="post_mlp_layernorm"), + ], } mapping["legacy"] += [ WeightRenaming( diff --git 
a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index 5f45081ac4a0..860a1bac23cf 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -203,6 +203,7 @@ from .jamba import * from .janus import * from .jetmoe import * + from .jina_embeddings_v3 import * from .kosmos2 import * from .kosmos2_5 import * from .kyutai_speech_to_text import * diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 476b5362343f..3893ca8838c9 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -237,6 +237,7 @@ ("jamba", "JambaConfig"), ("janus", "JanusConfig"), ("jetmoe", "JetMoeConfig"), + ("jina_embeddings_v3", "JinaEmbeddingsV3Config"), ("kosmos-2", "Kosmos2Config"), ("kosmos-2.5", "Kosmos2_5Config"), ("kyutai_speech_to_text", "KyutaiSpeechToTextConfig"), @@ -741,6 +742,7 @@ ("jamba", "Jamba"), ("janus", "Janus"), ("jetmoe", "JetMoe"), + ("jina_embeddings_v3", "JinaEmbeddingsV3"), ("kosmos-2", "KOSMOS-2"), ("kosmos-2.5", "KOSMOS-2.5"), ("kyutai_speech_to_text", "KyutaiSpeechToText"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 764d3b770e86..afb0658a456c 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -234,6 +234,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("jamba", "JambaModel"), ("janus", "JanusModel"), ("jetmoe", "JetMoeModel"), + ("jina_embeddings_v3", "JinaEmbeddingsV3Model"), ("kosmos-2", "Kosmos2Model"), ("kosmos-2.5", "Kosmos2_5Model"), ("kyutai_speech_to_text", "KyutaiSpeechToTextModel"), @@ -1049,6 +1050,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("fnet", "FNetForMaskedLM"), ("funnel", "FunnelForMaskedLM"), ("ibert", "IBertForMaskedLM"), + ("jina_embeddings_v3", "JinaEmbeddingsV3ForMaskedLM"), 
("layoutlm", "LayoutLMForMaskedLM"), ("longformer", "LongformerForMaskedLM"), ("luke", "LukeForMaskedLM"), @@ -1232,6 +1234,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("ibert", "IBertForSequenceClassification"), ("jamba", "JambaForSequenceClassification"), ("jetmoe", "JetMoeForSequenceClassification"), + ("jina_embeddings_v3", "JinaEmbeddingsV3ForSequenceClassification"), ("layoutlm", "LayoutLMForSequenceClassification"), ("layoutlmv2", "LayoutLMv2ForSequenceClassification"), ("layoutlmv3", "LayoutLMv3ForSequenceClassification"), @@ -1331,6 +1334,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("gpt_neox", "GPTNeoXForQuestionAnswering"), ("gptj", "GPTJForQuestionAnswering"), ("ibert", "IBertForQuestionAnswering"), + ("jina_embeddings_v3", "JinaEmbeddingsV3ForQuestionAnswering"), ("layoutlmv2", "LayoutLMv2ForQuestionAnswering"), ("layoutlmv3", "LayoutLMv3ForQuestionAnswering"), ("led", "LEDForQuestionAnswering"), @@ -1447,6 +1451,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("gpt_oss", "GptOssForTokenClassification"), ("helium", "HeliumForTokenClassification"), ("ibert", "IBertForTokenClassification"), + ("jina_embeddings_v3", "JinaEmbeddingsV3ForTokenClassification"), ("layoutlm", "LayoutLMForTokenClassification"), ("layoutlmv2", "LayoutLMv2ForTokenClassification"), ("layoutlmv3", "LayoutLMv3ForTokenClassification"), diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py index 4faaa9844315..46076b7f223f 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ b/src/transformers/models/auto/tokenization_auto.py @@ -161,6 +161,7 @@ ("instructblipvideo", "GPT2Tokenizer" if is_tokenizers_available() else None), ("internvl", "Qwen2Tokenizer" if is_tokenizers_available() else None), ("jais2", "GPT2Tokenizer" if is_tokenizers_available() else None), + ("jina_embeddings_v3", "XLMRobertaTokenizer" if is_tokenizers_available() else 
None), ("kosmos-2", "XLMRobertaTokenizer" if is_tokenizers_available() else None), ("lasr_ctc", "LasrTokenizer" if is_tokenizers_available() else None), ("lasr_encoder", "LasrTokenizer" if is_tokenizers_available() else None), diff --git a/src/transformers/models/jina_embeddings_v3/__init__.py b/src/transformers/models/jina_embeddings_v3/__init__.py new file mode 100644 index 000000000000..c0c33e9ff015 --- /dev/null +++ b/src/transformers/models/jina_embeddings_v3/__init__.py @@ -0,0 +1,29 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. + +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 + +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_jina_embeddings_v3 import * + from .modeling_jina_embeddings_v3 import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/src/transformers/models/jina_embeddings_v3/configuration_jina_embeddings_v3.py b/src/transformers/models/jina_embeddings_v3/configuration_jina_embeddings_v3.py new file mode 100644 index 000000000000..fea8d7adbf02 --- /dev/null +++ b/src/transformers/models/jina_embeddings_v3/configuration_jina_embeddings_v3.py @@ -0,0 +1,72 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/jina_embeddings_v3/modular_jina_embeddings_v3.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_jina_embeddings_v3.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# Copyright 2026 The Jina-AI and HuggingFace Inc. teams. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from huggingface_hub.dataclasses import strict + +from ...configuration_utils import PreTrainedConfig +from ...modeling_rope_utils import RopeParameters +from ...utils import auto_docstring + + +@auto_docstring(checkpoint="jinaai/jina-embeddings-v3-hf") +@strict(accept_kwargs=True) +class JinaEmbeddingsV3Config(PreTrainedConfig): + r""" + Examples: + + ```python + >>> from transformers import JinaEmbeddingsV3Config, JinaEmbeddingsV3Model + + >>> # Initializing a Jina-Embeddings-V3 jinaai/jina-embeddings-v3-hf style configuration + >>> configuration = JinaEmbeddingsV3Config() + + >>> # Initializing a model (with random weights) from the jinaai/jina-embeddings-v3-hf style configuration + >>> model = JinaEmbeddingsV3Model(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "jina_embeddings_v3" + + vocab_size: int = 250002 + hidden_size: int = 1024 + num_hidden_layers: int = 24 + num_attention_heads: int = 16 + intermediate_size: int = 4096 + hidden_act: str = "gelu" + hidden_dropout_prob: float = 0.1 + attention_probs_dropout_prob: float = 0.1 + max_position_embeddings: int = 8194 + type_vocab_size: int = 1 + initializer_range: float = 0.02 + layer_norm_eps: float = 1e-5 + pad_token_id: int | None = 1 + bos_token_id: int | None = 0 + eos_token_id: int | None = 2 + use_cache: bool = True + classifier_dropout: float | int | None = None + tie_word_embeddings: bool = True + default_theta = 20000.0 + rope_parameters: RopeParameters | dict | None = None + + +__all__ = ["JinaEmbeddingsV3Config"] diff --git a/src/transformers/models/jina_embeddings_v3/modeling_jina_embeddings_v3.py b/src/transformers/models/jina_embeddings_v3/modeling_jina_embeddings_v3.py new file mode 100644 index 000000000000..547fb54fa245 --- /dev/null +++ b/src/transformers/models/jina_embeddings_v3/modeling_jina_embeddings_v3.py @@ -0,0 +1,822 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically 
generated from src/transformers/models/jina_embeddings_v3/modular_jina_embeddings_v3.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_jina_embeddings_v3.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# Copyright 2026 The Jina-AI and HuggingFace Inc. teams. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections.abc import Callable +from typing import Optional + +import torch +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from ... 
import initialization as init +from ...activations import ACT2FN, gelu +from ...integrations import use_kernel_func_from_hub, use_kernelized_func +from ...masking_utils import create_bidirectional_mask +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import ( + BaseModelOutputWithPooling, + MaskedLMOutput, + QuestionAnsweringModelOutput, + SequenceClassifierOutput, + TokenClassifierOutput, +) +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import TransformersKwargs, auto_docstring +from ...utils.generic import can_return_tuple, maybe_autocast, merge_with_config_defaults +from ...utils.output_capturing import capture_outputs +from .configuration_jina_embeddings_v3 import JinaEmbeddingsV3Config + + +class JinaEmbeddingsV3Embeddings(nn.Module): + """Construct the embeddings from word, position and token_type embeddings.""" + + def __init__(self, config: JinaEmbeddingsV3Config): + super().__init__() + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) + self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) + + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.register_buffer( + "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False + ) + self.register_buffer( + "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False + ) + + def forward( + self, + input_ids: torch.LongTensor | None = None, + token_type_ids: torch.LongTensor | None = None, + position_ids: torch.LongTensor | None = None, + inputs_embeds: torch.FloatTensor | None = None, + ) -> 
torch.Tensor: + embeddings = inputs_embeds + if inputs_embeds is None: + embeddings = self.word_embeddings(input_ids) + + input_shape = embeddings.shape[:-1] + device = embeddings.device + + if position_ids is None: + position_ids = self.position_ids[:, : input_shape[1]] + + if token_type_ids is None: + if hasattr(self, "token_type_ids"): + # NOTE: We assume either pos ids to have bsz == 1 (broadcastable) or bsz == effective bsz (input_shape[0]) + buffered_token_type_ids = self.token_type_ids.expand(input_shape[0], -1) + buffered_token_type_ids = torch.gather(buffered_token_type_ids, dim=1, index=position_ids) + token_type_ids = buffered_token_type_ids + else: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) + + token_type_embeddings = self.token_type_embeddings(token_type_ids) + + embeddings = embeddings + token_type_embeddings + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + + return embeddings + + +class JinaEmbeddingsV3RotaryEmbedding(nn.Module): + inv_freq: torch.Tensor # fix linting for `register_buffer` + + def __init__(self, config: JinaEmbeddingsV3Config, device=None): + super().__init__() + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + + self.rope_type = self.config.rope_parameters["rope_type"] + rope_init_fn: Callable = self.compute_default_rope_parameters + if self.rope_type != "default": + rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + inv_freq, self.attention_scaling = rope_init_fn(self.config, device) + + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False) + + @staticmethod + def compute_default_rope_parameters( + config: JinaEmbeddingsV3Config | None = None, + device: Optional["torch.device"] = None, + seq_len: int | None = None, + ) -> tuple["torch.Tensor", float]: + """ + Computes the inverse 
frequencies according to the original RoPE implementation + Args: + config ([`~transformers.PreTrainedConfig`]): + The model configuration. + device (`torch.device`): + The device to use for initialization of the inverse frequencies. + seq_len (`int`, *optional*): + The current sequence length. Unused for this type of RoPE. + Returns: + Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the + post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE). + """ + base = config.rope_parameters["rope_theta"] + dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads + + attention_factor = 1.0 # Unused in this type of RoPE + + # Compute the inverse frequencies + inv_freq = 1.0 / ( + base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim) + ) + return inv_freq, attention_factor + + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. 
dynamic rope) + def forward(self, x, position_ids): + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) + position_ids_expanded = position_ids[:, None, :].float() + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +@use_kernel_func_from_hub("rotary_pos_emb") +def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. 
+ Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: torch.Tensor | None, + scaling: float | None = None, + dropout: float = 0.0, + **kwargs: Unpack[TransformersKwargs], +): + if scaling is None: + scaling = query.size(-1) ** -0.5 + + # Take the dot product between "query" and "key" to get the raw attention scores. + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling + + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +@use_kernelized_func(apply_rotary_pos_emb) +class JinaEmbeddingsV3Attention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: JinaEmbeddingsV3Config): + super().__init__() + self.config = config + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.scaling = self.head_dim**-0.5 + self.attention_dropout = config.attention_probs_dropout_prob + self.is_causal = False + + self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=True) + self.k_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=True) + self.v_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=True) + self.o_proj = 
nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=True) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.Tensor]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class JinaEmbeddingsV3MLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.activation_fn = ACT2FN[config.hidden_act] + self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) + self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +class JinaEmbeddingsV3Layer(GradientCheckpointingLayer): + def __init__(self, config: JinaEmbeddingsV3Config): + super().__init__() + self.post_attention_layernorm = 
nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + self.post_attention_dropout = nn.Dropout(config.hidden_dropout_prob) + self.post_mlp_dropout = nn.Dropout(config.hidden_dropout_prob) + self.mlp = JinaEmbeddingsV3MLP(config) + self.self_attn = JinaEmbeddingsV3Attention(config=config) + self.post_mlp_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> torch.FloatTensor: + residual = hidden_states + attention_output, _ = self.self_attn( + hidden_states, + attention_mask=attention_mask, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + self.post_attention_dropout(attention_output) + hidden_states = self.post_attention_layernorm(hidden_states) + + residual = hidden_states + hidden_states = self.mlp(hidden_states) + hidden_states = residual + self.post_mlp_dropout(hidden_states) + hidden_states = self.post_mlp_layernorm(hidden_states) + return hidden_states + + +class JinaEmbeddingsV3Pooler(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. 
+ first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +@auto_docstring +class JinaEmbeddingsV3PreTrainedModel(PreTrainedModel): + config_class = JinaEmbeddingsV3Config + base_model_prefix = "roberta" + supports_gradient_checkpointing = True + _supports_flash_attn = True + _supports_sdpa = True + _supports_flex_attn = True + _supports_attention_backend = True + _can_record_outputs = { + "hidden_states": JinaEmbeddingsV3Layer, + "attentions": JinaEmbeddingsV3Attention, + } + + @torch.no_grad() + def _init_weights(self, module): + """Initialize the weights""" + super()._init_weights(module) + if isinstance(module, JinaEmbeddingsV3LMHead): + init.zeros_(module.bias) + elif isinstance(module, JinaEmbeddingsV3Embeddings): + init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1))) + init.zeros_(module.token_type_ids) + + +@auto_docstring +class JinaEmbeddingsV3Model(JinaEmbeddingsV3PreTrainedModel): + _no_split_modules = ["JinaEmbeddingsV3Embeddings", "JinaEmbeddingsV3Layer"] + + def __init__(self, config: JinaEmbeddingsV3Config, add_pooling_layer=True): + r""" + add_pooling_layer (bool, *optional*, defaults to `True`): + Whether to add a pooling layer + """ + super().__init__(config) + self.config = config + self.gradient_checkpointing = False + + self.embeddings = JinaEmbeddingsV3Embeddings(config) + + self.pooler = JinaEmbeddingsV3Pooler(config) if add_pooling_layer else None + self.rotary_emb = JinaEmbeddingsV3RotaryEmbedding(config) + self.layers = nn.ModuleList([JinaEmbeddingsV3Layer(config) for _ in range(config.num_hidden_layers)]) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + @merge_with_config_defaults + @capture_outputs + 
@auto_docstring + def forward( + self, + input_ids: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + token_type_ids: torch.Tensor | None = None, + position_ids: torch.Tensor | None = None, + inputs_embeds: torch.Tensor | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutputWithPooling | tuple: + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds, + ) + hidden_states = embedding_output + + if position_ids is None: + # Default RoPE positions assume right padding; left padding requires explicit corrected position_ids for RoPE. + position_ids = torch.arange(hidden_states.shape[1], dtype=torch.long, device=hidden_states.device) + position_ids = position_ids.unsqueeze(0) + + position_embeddings = self.rotary_emb(embedding_output, position_ids) + + attention_mask = create_bidirectional_mask( + config=self.config, + inputs_embeds=embedding_output, + attention_mask=attention_mask, + ) + + for encoder_layer in self.layers[: self.config.num_hidden_layers]: + hidden_states = encoder_layer( + hidden_states, + attention_mask=attention_mask, + position_embeddings=position_embeddings, + **kwargs, + ) + + sequence_output = hidden_states + pooled_output = self.pooler(sequence_output) if self.pooler is not None else None + + return BaseModelOutputWithPooling( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + ) + + +class JinaEmbeddingsV3LMHead(nn.Module): + """JinaEmbeddingsV3 Head for masked language modeling.""" + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + self.decoder = nn.Linear(config.hidden_size, config.vocab_size) + 
self.bias = nn.Parameter(torch.zeros(config.vocab_size)) + + def forward(self, features, **kwargs): + x = self.dense(features) + x = gelu(x) + x = self.layer_norm(x) + + # project back to size of vocabulary with bias + x = self.decoder(x) + + return x + + +@auto_docstring +class JinaEmbeddingsV3ForMaskedLM(JinaEmbeddingsV3PreTrainedModel): + _tied_weights_keys = { + "lm_head.decoder.weight": "roberta.embeddings.word_embeddings.weight", + "lm_head.decoder.bias": "lm_head.bias", + } + + def __init__(self, config): + super().__init__(config=config) + + self.lm_head = JinaEmbeddingsV3LMHead(config) + self.roberta = JinaEmbeddingsV3Model(config, add_pooling_layer=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_output_embeddings(self): + return self.lm_head.decoder + + def set_output_embeddings(self, new_embeddings): + self.lm_head.decoder = new_embeddings + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: torch.LongTensor | None = None, + attention_mask: torch.FloatTensor | None = None, + token_type_ids: torch.LongTensor | None = None, + position_ids: torch.LongTensor | None = None, + inputs_embeds: torch.FloatTensor | None = None, + labels: torch.LongTensor | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.Tensor] | MaskedLMOutput: + r""" + token_type_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,1]`: + + - 0 corresponds to a *sentence A* token, + - 1 corresponds to a *sentence B* token. + This parameter can only be used when the model is initialized with `type_vocab_size` parameter with value + >= 2. All the value in this tensor should be always < type_vocab_size. 
+ + [What are token type IDs?](../glossary#token-type-ids) + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., + config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the + loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` + """ + outputs = self.roberta( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + return_dict=True, + **kwargs, + ) + sequence_output = outputs[0] + + prediction_scores = self.lm_head(sequence_output) + + masked_lm_loss = None + if labels is not None: + # move labels to correct device + labels = labels.to(prediction_scores.device) + loss_fct = CrossEntropyLoss() + masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) + + return MaskedLMOutput( + loss=masked_lm_loss, + logits=prediction_scores, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +class JinaEmbeddingsV3ClassificationHead(nn.Module): + """Head for sentence-level classification tasks.""" + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + classifier_dropout = ( + config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob + ) + self.dropout = nn.Dropout(classifier_dropout) + self.out_proj = nn.Linear(config.hidden_size, config.num_labels) + + def forward(self, features, **kwargs): + x = features[:, 0, :] # take token (equiv. 
to [CLS]) + x = self.dropout(x) + x = self.dense(x) + x = torch.tanh(x) + x = self.dropout(x) + x = self.out_proj(x) + return x + + +@auto_docstring( + custom_intro=""" + XLM-RoBERTa Model transformer with a sequence classification/regression head on top (a linear layer on top of the + pooled output) e.g. for GLUE tasks. + """ +) +class JinaEmbeddingsV3ForSequenceClassification(JinaEmbeddingsV3PreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.config = config + self.classifier = JinaEmbeddingsV3ClassificationHead(config) + + self.roberta = JinaEmbeddingsV3Model(config, add_pooling_layer=False) + + # Initialize weights and apply final processing + self.post_init() + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: torch.LongTensor | None = None, + attention_mask: torch.FloatTensor | None = None, + token_type_ids: torch.LongTensor | None = None, + position_ids: torch.LongTensor | None = None, + inputs_embeds: torch.FloatTensor | None = None, + labels: torch.LongTensor | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.Tensor] | SequenceClassifierOutput: + r""" + token_type_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,1]`: + + - 0 corresponds to a *sentence A* token, + - 1 corresponds to a *sentence B* token. + This parameter can only be used when the model is initialized with `type_vocab_size` parameter with value + >= 2. All the value in this tensor should be always < type_vocab_size. + + [What are token type IDs?](../glossary#token-type-ids) + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. 
If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + outputs = self.roberta( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + return_dict=True, + **kwargs, + ) + sequence_output = outputs[0] + logits = self.classifier(sequence_output) + + loss = None + if labels is not None: + # move labels to correct device + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + + return SequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@auto_docstring +class JinaEmbeddingsV3ForTokenClassification(JinaEmbeddingsV3PreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + classifier_dropout = ( + config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob + ) + self.dropout = nn.Dropout(classifier_dropout) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + self.roberta = 
JinaEmbeddingsV3Model(config, add_pooling_layer=False) + + # Initialize weights and apply final processing + self.post_init() + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: torch.LongTensor | None = None, + attention_mask: torch.FloatTensor | None = None, + token_type_ids: torch.LongTensor | None = None, + position_ids: torch.LongTensor | None = None, + inputs_embeds: torch.FloatTensor | None = None, + labels: torch.LongTensor | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.Tensor] | TokenClassifierOutput: + r""" + token_type_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,1]`: + + - 0 corresponds to a *sentence A* token, + - 1 corresponds to a *sentence B* token. + This parameter can only be used when the model is initialized with `type_vocab_size` parameter with value + >= 2. All the value in this tensor should be always < type_vocab_size. + + [What are token type IDs?](../glossary#token-type-ids) + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. 
+ """ + outputs = self.roberta( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + return_dict=True, + **kwargs, + ) + + sequence_output = outputs[0] + + sequence_output = self.dropout(sequence_output) + logits = self.classifier(sequence_output) + + loss = None + if labels is not None: + # move labels to correct device + labels = labels.to(logits.device) + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@auto_docstring +class JinaEmbeddingsV3ForQuestionAnswering(JinaEmbeddingsV3PreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels) + + self.roberta = JinaEmbeddingsV3Model(config, add_pooling_layer=False) + + # Initialize weights and apply final processing + self.post_init() + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: torch.LongTensor | None = None, + attention_mask: torch.FloatTensor | None = None, + token_type_ids: torch.LongTensor | None = None, + position_ids: torch.LongTensor | None = None, + inputs_embeds: torch.FloatTensor | None = None, + start_positions: torch.LongTensor | None = None, + end_positions: torch.LongTensor | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.Tensor] | QuestionAnsweringModelOutput: + r""" + token_type_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,1]`: + + - 0 corresponds to a *sentence A* token, + - 1 corresponds to a *sentence B* token. 
+ This parameter can only be used when the model is initialized with `type_vocab_size` parameter with value + >= 2. All the value in this tensor should be always < type_vocab_size. + + [What are token type IDs?](../glossary#token-type-ids) + """ + outputs = self.roberta( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + return_dict=True, + **kwargs, + ) + + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index) + end_positions = end_positions.clamp(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + return QuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +__all__ = [ + "JinaEmbeddingsV3PreTrainedModel", + "JinaEmbeddingsV3Model", + "JinaEmbeddingsV3ForMaskedLM", + "JinaEmbeddingsV3ForSequenceClassification", + "JinaEmbeddingsV3ForTokenClassification", + "JinaEmbeddingsV3ForQuestionAnswering", + "JinaEmbeddingsV3Layer", +] diff --git a/src/transformers/models/jina_embeddings_v3/modular_jina_embeddings_v3.py 
b/src/transformers/models/jina_embeddings_v3/modular_jina_embeddings_v3.py new file mode 100644 index 000000000000..bb38c96b687e --- /dev/null +++ b/src/transformers/models/jina_embeddings_v3/modular_jina_embeddings_v3.py @@ -0,0 +1,404 @@ +# Copyright 2026 The Jina-AI and HuggingFace Inc. teams. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections.abc import Callable + +import torch +from huggingface_hub.dataclasses import strict +from torch import nn +from torch.nn import CrossEntropyLoss + +from ...integrations import use_kernelized_func +from ...masking_utils import create_bidirectional_mask +from ...modeling_outputs import ( + BaseModelOutputWithPooling, + MaskedLMOutput, +) +from ...modeling_rope_utils import RopeParameters +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS +from ...processing_utils import Unpack +from ...utils import TransformersKwargs, auto_docstring, logging +from ...utils.generic import can_return_tuple, merge_with_config_defaults +from ...utils.output_capturing import capture_outputs +from ..clip.modeling_clip import CLIPMLP +from ..gpt_neox.modeling_gpt_neox import GPTNeoXLayer +from ..llama.modeling_llama import LlamaAttention, LlamaRotaryEmbedding, apply_rotary_pos_emb +from ..xlm_roberta.configuration_xlm_roberta import XLMRobertaConfig +from ..xlm_roberta.modeling_xlm_roberta import ( + XLMRobertaEmbeddings, + XLMRobertaForMaskedLM, + XLMRobertaForQuestionAnswering, + 
XLMRobertaForSequenceClassification, + XLMRobertaForTokenClassification, + XLMRobertaLMHead, + XLMRobertaModel, + XLMRobertaPooler, + XLMRobertaPreTrainedModel, + eager_attention_forward, +) + + +logger = logging.get_logger(__name__) + + +@auto_docstring(checkpoint="jinaai/jina-embeddings-v3-hf") +@strict(accept_kwargs=True) +class JinaEmbeddingsV3Config(XLMRobertaConfig): + r""" + Examples: + + ```python + >>> from transformers import JinaEmbeddingsV3Config, JinaEmbeddingsV3Model + + >>> # Initializing a Jina-Embeddings-V3 jinaai/jina-embeddings-v3-hf style configuration + >>> configuration = JinaEmbeddingsV3Config() + + >>> # Initializing a model (with random weights) from the jinaai/jina-embeddings-v3-hf style configuration + >>> model = JinaEmbeddingsV3Model(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "jina_embeddings_v3" + default_theta = 20000.0 + + vocab_size: int = 250002 + hidden_size: int = 1024 + num_hidden_layers: int = 24 + num_attention_heads: int = 16 + intermediate_size: int = 4096 + max_position_embeddings: int = 8194 + type_vocab_size: int = 1 + layer_norm_eps: float = 1e-5 + rope_parameters: RopeParameters | dict | None = None + + add_cross_attention = AttributeError() + is_decoder = AttributeError() + + +class JinaEmbeddingsV3Embeddings(XLMRobertaEmbeddings): + def __init__(self, config: JinaEmbeddingsV3Config): + super().__init__(config) + + del self.padding_idx + del self.position_embeddings + + def create_position_ids_from_inputs_embeds(): + raise AttributeError("Not needed for JinaEmbeddingsV3") + + def create_position_ids_from_input_ids(): + raise AttributeError("Not needed for JinaEmbeddingsV3") + + def forward( + self, + input_ids: torch.LongTensor | None = None, + token_type_ids: torch.LongTensor | None = None, + position_ids: torch.LongTensor | None = None, + inputs_embeds: torch.FloatTensor | None = None, + ) -> torch.Tensor: + embeddings = inputs_embeds + if 
inputs_embeds is None: + embeddings = self.word_embeddings(input_ids) + + input_shape = embeddings.shape[:-1] + device = embeddings.device + + if position_ids is None: + position_ids = self.position_ids[:, : input_shape[1]] + + if token_type_ids is None: + if hasattr(self, "token_type_ids"): + # NOTE: We assume either pos ids to have bsz == 1 (broadcastable) or bsz == effective bsz (input_shape[0]) + buffered_token_type_ids = self.token_type_ids.expand(input_shape[0], -1) + buffered_token_type_ids = torch.gather(buffered_token_type_ids, dim=1, index=position_ids) + token_type_ids = buffered_token_type_ids + else: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) + + token_type_embeddings = self.token_type_embeddings(token_type_ids) + + embeddings = embeddings + token_type_embeddings + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + + return embeddings + + +class JinaEmbeddingsV3RotaryEmbedding(LlamaRotaryEmbedding): + pass + + +@use_kernelized_func(apply_rotary_pos_emb) +class JinaEmbeddingsV3Attention(LlamaAttention): + def __init__(self, config: JinaEmbeddingsV3Config): + super().__init__(config) + self.is_causal = False + self.attention_dropout = config.attention_probs_dropout_prob + + self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=True) + self.k_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=True) + self.v_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=True) + self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=True) + + del self.layer_idx + del self.num_key_value_groups + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.Tensor]: + input_shape = 
hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class JinaEmbeddingsV3MLP(CLIPMLP): + pass + + +class JinaEmbeddingsV3Layer(GPTNeoXLayer): + def __init__(self, config: JinaEmbeddingsV3Config): + super().__init__(config) + self.self_attn = JinaEmbeddingsV3Attention(config=config) + + self.post_attention_dropout = nn.Dropout(config.hidden_dropout_prob) + self.post_mlp_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.post_mlp_dropout = nn.Dropout(config.hidden_dropout_prob) + + del self.use_parallel_residual + del self.input_layernorm + del self.attention + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> torch.FloatTensor: + residual = hidden_states + attention_output, _ = self.self_attn( + hidden_states, + attention_mask=attention_mask, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + self.post_attention_dropout(attention_output) + hidden_states = 
self.post_attention_layernorm(hidden_states) + + residual = hidden_states + hidden_states = self.mlp(hidden_states) + hidden_states = residual + self.post_mlp_dropout(hidden_states) + hidden_states = self.post_mlp_layernorm(hidden_states) + return hidden_states + + +class JinaEmbeddingsV3Pooler(XLMRobertaPooler): + pass + + +class JinaEmbeddingsV3PreTrainedModel(XLMRobertaPreTrainedModel): + _can_record_outputs = { + "hidden_states": JinaEmbeddingsV3Layer, + "attentions": JinaEmbeddingsV3Attention, + } + + +@auto_docstring +class JinaEmbeddingsV3Model(XLMRobertaModel): + def __init__(self, config: JinaEmbeddingsV3Config, add_pooling_layer=True): + super().__init__(config) + self.rotary_emb = JinaEmbeddingsV3RotaryEmbedding(config) + self.layers = nn.ModuleList([JinaEmbeddingsV3Layer(config) for _ in range(config.num_hidden_layers)]) + del self.encoder + + # Initialize weights and apply final processing + self.post_init() + + @merge_with_config_defaults + @capture_outputs + @auto_docstring + def forward( + self, + input_ids: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + token_type_ids: torch.Tensor | None = None, + position_ids: torch.Tensor | None = None, + inputs_embeds: torch.Tensor | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutputWithPooling | tuple: + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds, + ) + hidden_states = embedding_output + + if position_ids is None: + # Default RoPE positions assume right padding; left padding requires explicit corrected position_ids for RoPE. 
+ position_ids = torch.arange(hidden_states.shape[1], dtype=torch.long, device=hidden_states.device) + position_ids = position_ids.unsqueeze(0) + + position_embeddings = self.rotary_emb(embedding_output, position_ids) + + attention_mask = create_bidirectional_mask( + config=self.config, + inputs_embeds=embedding_output, + attention_mask=attention_mask, + ) + + for encoder_layer in self.layers[: self.config.num_hidden_layers]: + hidden_states = encoder_layer( + hidden_states, + attention_mask=attention_mask, + position_embeddings=position_embeddings, + **kwargs, + ) + + sequence_output = hidden_states + pooled_output = self.pooler(sequence_output) if self.pooler is not None else None + + return BaseModelOutputWithPooling( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + ) + + def _create_attention_masks(self): + raise AttributeError("Not needed for JinaEmbeddingsV3") + + +class JinaEmbeddingsV3LMHead(XLMRobertaLMHead): + pass + + +class JinaEmbeddingsV3ForMaskedLM(XLMRobertaForMaskedLM): + def __init__(self, config): + JinaEmbeddingsV3PreTrainedModel.__init__(self, config=config) + + self.lm_head = JinaEmbeddingsV3LMHead(config) + self.roberta = JinaEmbeddingsV3Model(config, add_pooling_layer=False) + + # Initialize weights and apply final processing + self.post_init() + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: torch.LongTensor | None = None, + attention_mask: torch.FloatTensor | None = None, + token_type_ids: torch.LongTensor | None = None, + position_ids: torch.LongTensor | None = None, + inputs_embeds: torch.FloatTensor | None = None, + labels: torch.LongTensor | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.Tensor] | MaskedLMOutput: + r""" + token_type_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Segment token indices to indicate first and second portions of the inputs. 
Indices are selected in `[0,1]`: + + - 0 corresponds to a *sentence A* token, + - 1 corresponds to a *sentence B* token. + This parameter can only be used when the model is initialized with `type_vocab_size` parameter with value + >= 2. All the value in this tensor should be always < type_vocab_size. + + [What are token type IDs?](../glossary#token-type-ids) + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., + config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the + loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` + """ + outputs = self.roberta( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + return_dict=True, + **kwargs, + ) + sequence_output = outputs[0] + + prediction_scores = self.lm_head(sequence_output) + + masked_lm_loss = None + if labels is not None: + # move labels to correct device + labels = labels.to(prediction_scores.device) + loss_fct = CrossEntropyLoss() + masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) + + return MaskedLMOutput( + loss=masked_lm_loss, + logits=prediction_scores, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +class JinaEmbeddingsV3ForSequenceClassification(XLMRobertaForSequenceClassification): + pass + + +class JinaEmbeddingsV3ForTokenClassification(XLMRobertaForTokenClassification): + pass + + +class JinaEmbeddingsV3ForQuestionAnswering(XLMRobertaForQuestionAnswering): + pass + + +__all__ = [ + "JinaEmbeddingsV3Config", + "JinaEmbeddingsV3PreTrainedModel", + "JinaEmbeddingsV3Model", + "JinaEmbeddingsV3ForMaskedLM", + "JinaEmbeddingsV3ForSequenceClassification", + "JinaEmbeddingsV3ForTokenClassification", + 
"JinaEmbeddingsV3ForQuestionAnswering", + "JinaEmbeddingsV3Layer", +] diff --git a/tests/models/jina_embeddings_v3/__init__.py b/tests/models/jina_embeddings_v3/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/models/jina_embeddings_v3/test_modeling_jina_embeddings_v3.py b/tests/models/jina_embeddings_v3/test_modeling_jina_embeddings_v3.py new file mode 100644 index 000000000000..a2395613ba4b --- /dev/null +++ b/tests/models/jina_embeddings_v3/test_modeling_jina_embeddings_v3.py @@ -0,0 +1,432 @@ +# Copyright 2026 the HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
import unittest

from transformers import AutoModel, AutoTokenizer, is_torch_available
from transformers.models.jina_embeddings_v3 import JinaEmbeddingsV3Config
from transformers.testing_utils import (
    cleanup,
    require_torch,
    slow,
    torch_device,
)

from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
from ...test_pipeline_mixin import PipelineTesterMixin


if is_torch_available():
    import torch

    from transformers import (
        JinaEmbeddingsV3ForMaskedLM,
        JinaEmbeddingsV3ForQuestionAnswering,
        JinaEmbeddingsV3ForSequenceClassification,
        JinaEmbeddingsV3ForTokenClassification,
        JinaEmbeddingsV3Model,
    )


class JinaEmbeddingsV3ModelTester:
    """Builds tiny configs and random inputs for the JinaEmbeddingsV3 unit tests."""

    def __init__(
        self,
        parent,
        batch_size=13,
        seq_length=7,
        is_training=True,
        use_input_mask=True,
        use_token_type_ids=True,
        use_labels=True,
        vocab_size=32,
        hidden_size=16,
        num_hidden_layers=2,
        num_attention_heads=4,
        intermediate_size=20,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=8,
        type_vocab_size=1,
        type_sequence_label_size=2,
        initializer_range=0.02,
        num_labels=3,
        scope=None,
    ):
        self.parent = parent
        # Input shape and feature toggles.
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.is_training = is_training
        self.use_input_mask = use_input_mask
        self.use_token_type_ids = use_token_type_ids
        self.use_labels = use_labels
        # Tiny-model architecture hyperparameters.
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        self.type_sequence_label_size = type_sequence_label_size
        self.initializer_range = initializer_range
        self.num_labels = num_labels
        self.scope = scope

    def prepare_config_and_inputs(self):
        """Return a config plus random ids/mask/labels matching the tester's dimensions."""
        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

        input_mask = (
            random_attention_mask([self.batch_size, self.seq_length]) if self.use_input_mask else None
        )
        token_type_ids = (
            ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
            if self.use_token_type_ids
            else None
        )

        sequence_labels = None
        token_labels = None
        if self.use_labels:
            sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
            token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)

        config = self.get_config()

        return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels

    def get_config(self):
        return JinaEmbeddingsV3Config(
            vocab_size=self.vocab_size,
            hidden_size=self.hidden_size,
            num_hidden_layers=self.num_hidden_layers,
            num_attention_heads=self.num_attention_heads,
            intermediate_size=self.intermediate_size,
            hidden_act=self.hidden_act,
            hidden_dropout_prob=self.hidden_dropout_prob,
            attention_probs_dropout_prob=self.attention_probs_dropout_prob,
            max_position_embeddings=self.max_position_embeddings,
            type_vocab_size=self.type_vocab_size,
            initializer_range=self.initializer_range,
        )

    def create_and_check_model(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels):
        model = JinaEmbeddingsV3Model(config=config)
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
        self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size))

        result = model(input_ids, token_type_ids=token_type_ids)
        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
(self.batch_size, self.seq_length, self.hidden_size)) + self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size)) + + result = model(input_ids) + self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)) + self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size)) + + def create_and_check_for_masked_lm( + self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels + ): + model = JinaEmbeddingsV3ForMaskedLM(config=config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels) + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size)) + + def create_and_check_for_question_answering( + self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels + ): + model = JinaEmbeddingsV3ForQuestionAnswering(config=config) + model.to(torch_device) + model.eval() + result = model( + input_ids, + attention_mask=input_mask, + token_type_ids=token_type_ids, + start_positions=sequence_labels, + end_positions=sequence_labels, + ) + self.parent.assertEqual(result.start_logits.shape, (self.batch_size, self.seq_length)) + self.parent.assertEqual(result.end_logits.shape, (self.batch_size, self.seq_length)) + + def create_and_check_for_sequence_classification( + self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels + ): + config.num_labels = self.num_labels + model = JinaEmbeddingsV3ForSequenceClassification(config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=sequence_labels) + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels)) + + def create_and_check_for_token_classification( + self, + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + 
token_labels, + ): + config.num_labels = self.num_labels + model = JinaEmbeddingsV3ForTokenClassification(config=config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels) + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.num_labels)) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + ( + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + ) = config_and_inputs + + inputs_dict = { + "input_ids": input_ids, + "token_type_ids": token_type_ids, + "attention_mask": input_mask, + } + return config, inputs_dict + + +@require_torch +class JinaEmbeddingsV3ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): + all_model_classes = ( + ( + JinaEmbeddingsV3Model, + JinaEmbeddingsV3ForMaskedLM, + JinaEmbeddingsV3ForQuestionAnswering, + JinaEmbeddingsV3ForSequenceClassification, + JinaEmbeddingsV3ForTokenClassification, + ) + if is_torch_available() + else () + ) + pipeline_model_mapping = ( + { + "feature-extraction": JinaEmbeddingsV3Model, + "fill-mask": JinaEmbeddingsV3ForMaskedLM, + "text-classification": JinaEmbeddingsV3ForSequenceClassification, + "token-classification": JinaEmbeddingsV3ForTokenClassification, + "zero-shot": JinaEmbeddingsV3ForSequenceClassification, + } + if is_torch_available() + else {} + ) + + def setUp(self): + self.model_tester = JinaEmbeddingsV3ModelTester(self) + self.config_tester = ConfigTester(self, config_class=JinaEmbeddingsV3Config, hidden_size=37) + + def test_config(self): + self.config_tester.run_common_tests() + + def test_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(*config_and_inputs) + + def test_for_masked_lm(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + 
self.model_tester.create_and_check_for_masked_lm(*config_and_inputs) + + def test_for_question_answering(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_for_question_answering(*config_and_inputs) + + def test_for_sequence_classification(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_for_sequence_classification(*config_and_inputs) + + def test_for_token_classification(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_for_token_classification(*config_and_inputs) + + +@require_torch +class JinaEmbeddingsV3ModelIntegrationTest(unittest.TestCase): + model_id = "jinaai/jina-embeddings-v3-hf" + prompt = "Jina Embeddings V3 is great for semantic search." + + def setup(self): + cleanup(torch_device, gc_collect=True) + + def tearDown(self): + cleanup(torch_device, gc_collect=True) + + def _prepare_inputs(self): + tokenizer = AutoTokenizer.from_pretrained(self.model_id) + inputs = tokenizer(self.prompt, return_tensors="pt", padding=True) + return inputs + + @slow + def test_inference_no_head_absolute_embedding(self): + model = AutoModel.from_pretrained(self.model_id, dtype=torch.float32) + model.eval() + inputs = self._prepare_inputs() + + with torch.no_grad(): + output = model(**inputs)[0] + + expected_shape = torch.Size((1, 17, 1024)) + self.assertEqual(output.shape, expected_shape) + + expected_slice = torch.tensor( + [ + [ + [-3.1011, 0.8560, -0.2491, 0.9427, 1.4015, -1.1527, 1.3804, -0.5453, -1.8164], + [-3.1108, 1.0107, -0.2097, 1.3495, 0.9984, -0.9518, 1.3189, -0.6295, -2.1128], + [-2.7095, 0.6469, -0.4475, 1.1364, 1.5975, -0.7545, 1.0803, 0.5199, -2.3569], + ] + ] + ) + + torch.testing.assert_close(output[:, 1:4, 1:10], expected_slice, rtol=1e-4, atol=1e-4) + + @slow + def test_inference_retrieval_query_adapter(self): + task = "retrieval_query" + model = 
AutoModel.from_pretrained(self.model_id, dtype=torch.float32) + model.load_adapter(self.model_id, adapter_name=task, adapter_kwargs={"subfolder": task}) + model.set_adapter(task) + model.eval() + inputs = self._prepare_inputs() + + with torch.no_grad(): + output = model(**inputs)[0] + + self.assertEqual(output.shape, torch.Size((1, 17, 1024))) + expected_slice = torch.tensor( + [ + [ + [-1.9765, 0.7356, -0.4414, 0.5823, 2.1507, -0.8906, 0.0233, -0.2389, -1.5708], + [-2.0078, 0.9562, -0.3315, 1.0080, 1.8247, -0.6678, -0.2505, -0.3441, -1.9328], + [-1.9107, 0.7120, -0.4675, 0.9436, 2.1607, -0.4170, -0.1513, 1.0063, -2.0103], + ] + ] + ) + + torch.testing.assert_close(output[:, 1:4, 1:10], expected_slice, rtol=1e-4, atol=1e-4) + + @slow + def test_inference_retrieval_passage_adapter(self): + task = "retrieval_passage" + model = AutoModel.from_pretrained(self.model_id, dtype=torch.float32) + model.load_adapter(self.model_id, adapter_name=task, adapter_kwargs={"subfolder": task}) + model.set_adapter(task) + model.eval() + inputs = self._prepare_inputs() + + with torch.no_grad(): + output = model(**inputs)[0] + + expected_shape = torch.Size((1, 17, 1024)) + self.assertEqual(output.shape, expected_shape) + + expected_slice = torch.tensor( + [ + [ + [-1.7028, 0.5688, -0.8541, 0.4696, 2.5396, -0.8374, -0.1404, -0.3123, -1.4636], + [-1.6631, 0.6571, -0.8641, 0.9177, 2.3502, -0.6578, -0.3763, -0.3975, -1.7684], + [-1.4739, 0.4739, -0.8745, 0.8812, 2.6848, -0.4496, -0.4964, 0.6403, -2.0821], + ] + ] + ) + + torch.testing.assert_close(output[:, 1:4, 1:10], expected_slice, rtol=1e-4, atol=1e-4) + + @slow + def test_inference_separation_adapter(self): + task = "separation" + model = AutoModel.from_pretrained(self.model_id, dtype=torch.float32) + model.load_adapter(self.model_id, adapter_name=task, adapter_kwargs={"subfolder": task}) + model.set_adapter(task) + model.eval() + + inputs = self._prepare_inputs() + + with torch.no_grad(): + output = model(**inputs)[0] + + 
self.assertEqual(output.shape, torch.Size((1, 17, 1024))) + expected_slice = torch.tensor( + [ + [ + [-3.0336, 1.4392, 0.2875, 0.7660, 0.7054, -1.1701, 1.6121, -0.6325, -1.5177], + [-3.0875, 1.5134, 0.3620, 1.0281, 0.4895, -1.0484, 1.6574, -0.7636, -1.6736], + [-2.7605, 1.2920, 0.2223, 0.9895, 0.8515, -0.9050, 1.5558, 0.1410, -1.8531], + ] + ] + ) + + torch.testing.assert_close(output[:, 1:4, 1:10], expected_slice, rtol=1e-4, atol=1e-4) + + @slow + def test_inference_classification_adapter(self): + task = "classification" + model = AutoModel.from_pretrained(self.model_id, dtype=torch.float32) + model.load_adapter(self.model_id, adapter_name=task, adapter_kwargs={"subfolder": task}) + model.set_adapter(task) + model.eval() + + inputs = self._prepare_inputs() + + with torch.no_grad(): + output = model(**inputs)[0] + + self.assertEqual(output.shape, torch.Size((1, 17, 1024))) + expected_slice = torch.tensor( + [ + [ + [-2.7150, 0.2485, 1.2297, 0.6988, 0.9804, -1.2831, 1.3446, -0.1663, -0.6874], + [-2.8101, 0.1711, 1.2010, 0.9873, 0.5092, -1.3312, 1.4633, -0.2467, -0.7835], + [-2.6067, 0.2362, 0.6945, 1.0134, 0.7105, -1.3767, 0.9999, 0.4427, -1.1153], + ] + ] + ) + + torch.testing.assert_close(output[:, 1:4, 1:10], expected_slice, rtol=1e-4, atol=1e-4) + + @slow + def test_inference_text_matching_adapter(self): + task = "text_matching" + model = AutoModel.from_pretrained(self.model_id, dtype=torch.float32) + model.load_adapter(self.model_id, adapter_name=task, adapter_kwargs={"subfolder": task}) + model.set_adapter(task) + model.eval() + + inputs = self._prepare_inputs() + + with torch.no_grad(): + output = model(**inputs)[0] + + self.assertEqual(output.shape, torch.Size((1, 17, 1024))) + expected_slice = torch.tensor( + [ + [ + [-1.5888, 1.0527, 0.1237, -0.0822, 1.6507, -1.0371, -0.8815, -0.8082, -0.6564], + [-1.6529, 1.3143, 0.1957, 0.2914, 1.4897, -0.8735, -1.0067, -0.7544, -1.0513], + [-1.5308, 1.4805, -0.1393, 0.3879, 1.4373, -0.6064, -1.6436, 0.4793, -1.3388], + 
] + ] + ) + + torch.testing.assert_close(output[:, 1:4, 1:10], expected_slice, rtol=1e-4, atol=1e-4)