import logging
from collections.abc import Callable
from typing import Any

import torch


def _instruction_template(
    instruction: str, prompt_type: PromptType | None = None
) -> str:
    """Wrap a task instruction in the NV-Embed query prompt format.

    Args:
        instruction: Task instruction text.
        prompt_type: Accepted so the function matches the callable
            instruction-template signature; it does not affect the result.

    Returns:
        ``"Instruct: {instruction}\\nQuery: "``, or an empty string when
        ``instruction`` is empty.
    """
    return f"Instruct: {instruction}\nQuery: " if instruction else ""


class _NVEmbedWrapper(InstructSentenceTransformerModel):
    """Loader for NV-Embed models pinned to sentence-transformers 2.x.

    Inherits from ``InstructSentenceTransformerModel`` but provides its own
    constructor: NV-Embed requires ``sbert==2``, which does not have tokenizer
    kwargs, so tokenizer options are set on the loaded tokenizer instead.
    """

    def __init__(
        self,
        model_name: str,
        revision: str,
        instruction_template: str
        | Callable[[str, PromptType | None], str]
        | None = None,
        max_seq_length: int | None = None,
        apply_instruction_to_passages: bool = True,
        padding_side: str | None = None,
        add_eos_token: bool = False,
        prompts_dict: dict[str, str] | None = None,
        **kwargs: Any,
    ):
        """Validate the environment and load the model.

        Args:
            model_name: Model identifier passed to ``SentenceTransformer``.
            revision: Model revision passed to ``SentenceTransformer``.
            instruction_template: Either a format string containing
                ``{instruction}``, a callable building the prompt, or ``None``
                to use instructions as-is (a warning is logged).
            max_seq_length: If truthy, overrides the model's ``max_seq_length``.
            apply_instruction_to_passages: Stored on the instance unchanged.
            padding_side: If given, set on the loaded tokenizer; ``None`` keeps
                the tokenizer's own padding side.
            add_eos_token: Set on the loaded tokenizer.
            prompts_dict: Stored on the instance unchanged.
            **kwargs: Forwarded to ``SentenceTransformer``.

        Raises:
            RuntimeError: If the installed ``transformers`` or
                ``sentence_transformers`` version differs from the pinned one.
            ValueError: If ``instruction_template`` is a string without
                ``{instruction}``.
        """
        # NOTE: the parent constructor is not called here; this method sets
        # the instance attributes itself (see class docstring for the reason).
        from sentence_transformers import __version__ as sbert_version

        # Exact version pins; any other version is rejected up front.
        pinned_versions = (
            ("transformers", transformers_version, "4.42.4"),
            ("sbert", sbert_version, "2.7.0"),
        )
        for package, installed, required in pinned_versions:
            if Version(installed) != Version(required):
                # Report the model actually being loaded: this wrapper is the
                # loader for more than one NV-Embed checkpoint.
                raise RuntimeError(
                    f"{package} version {installed} is not match with required "
                    f"install version {required} to run `{model_name}`"
                )

        requires_package(
            self, "flash_attn", model_name, "pip install 'mteb[flash_attention]'"
        )

        from sentence_transformers import SentenceTransformer

        if (
            isinstance(instruction_template, str)
            and "{instruction}" not in instruction_template
        ):
            raise ValueError(
                "Instruction template must contain the string '{instruction}'."
            )
        if instruction_template is None:
            logger.warning(
                "No instruction template provided. Instructions will be used as-is."
            )

        self.instruction_template = instruction_template

        self.model_name = model_name
        self.model = SentenceTransformer(model_name, revision=revision, **kwargs)
        # Only override the padding side when the caller asked for one;
        # assigning None would replace the tokenizer's own valid setting.
        if padding_side is not None:
            self.model.tokenizer.padding_side = padding_side
        self.model.tokenizer.add_eos_token = add_eos_token

        if max_seq_length:
            # https://github.com/huggingface/sentence-transformers/issues/3575
            self.model.max_seq_length = max_seq_length
        self.apply_instruction_to_passages = apply_instruction_to_passages
        self.prompts_dict = prompts_dict