Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 73 additions & 5 deletions mteb/models/model_implementations/nvidia_models.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
from collections.abc import Callable
from typing import Any

import torch
Expand Down Expand Up @@ -29,7 +30,7 @@
}"""


def instruction_template(
def _instruction_template(
instruction: str, prompt_type: PromptType | None = None
) -> str:
return f"Instruct: {instruction}\nQuery: " if instruction else ""
Expand Down Expand Up @@ -100,10 +101,77 @@ def instruction_template(
"MrTidyRetrieval",
}


class _NVEmbedWrapper(InstructSentenceTransformerModel):
"""Inherited, because nvembed requires `sbert==2`, but it doesn't have tokenizers kwargs"""

def __init__(
self,
model_name: str,
revision: str,
instruction_template: str
| Callable[[str, PromptType | None], str]
| None = None,
max_seq_length: int | None = None,
apply_instruction_to_passages: bool = True,
padding_side: str | None = None,
add_eos_token: bool = False,
prompts_dict: dict[str, str] | None = None,
**kwargs: Any,
):
from sentence_transformers import __version__ as sbert_version

required_transformers_version = "4.42.4"
required_sbert_version = "2.7.0"

if Version(transformers_version) != Version(required_transformers_version):
raise RuntimeError(
f"transformers version {transformers_version} is not match with required "
f"install version {required_transformers_version} to run `nvidia/NV-Embed-v2`"
)

if Version(sbert_version) != Version(required_sbert_version):
raise RuntimeError(
f"sbert version {sbert_version} is not match with required "
f"install version {required_sbert_version} to run `nvidia/NV-Embed-v2`"
)

requires_package(
self, "flash_attn", model_name, "pip install 'mteb[flash_attention]'"
)

from sentence_transformers import SentenceTransformer

if (
isinstance(instruction_template, str)
and "{instruction}" not in instruction_template
):
raise ValueError(
"Instruction template must contain the string '{instruction}'."
)
if instruction_template is None:
logger.warning(
"No instruction template provided. Instructions will be used as-is."
)

self.instruction_template = instruction_template

self.model_name = model_name
self.model = SentenceTransformer(model_name, revision=revision, **kwargs)
self.model.tokenizer.padding_side = padding_side
self.model.tokenizer.add_eos_token = add_eos_token

if max_seq_length:
# https://github.com/huggingface/sentence-transformers/issues/3575
self.model.max_seq_length = max_seq_length
self.apply_instruction_to_passages = apply_instruction_to_passages
self.prompts_dict = prompts_dict


NV_embed_v2 = ModelMeta(
loader=InstructSentenceTransformerModel,
loader=_NVEmbedWrapper,
loader_kwargs=dict(
instruction_template=instruction_template,
instruction_template=_instruction_template,
trust_remote_code=True,
max_seq_length=32768,
padding_side="right",
Expand Down Expand Up @@ -131,9 +199,9 @@ def instruction_template(
)

NV_embed_v1 = ModelMeta(
loader=InstructSentenceTransformerModel,
loader=_NVEmbedWrapper,
loader_kwargs=dict(
instruction_template=instruction_template,
instruction_template=_instruction_template,
trust_remote_code=True,
max_seq_length=32768,
padding_side="right",
Expand Down
Loading