Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,17 +1,15 @@
from typing import TYPE_CHECKING, Any
from typing import Any

import torch
from packaging.version import Version
from torch.utils.data import DataLoader
from transformers import __version__ as transformers_version

from mteb.abstasks.task_metadata import TaskMetadata
from mteb.models.abs_encoder import AbsEncoder
from mteb.models.model_meta import ModelMeta
from mteb.types import Array, BatchedInput, PromptType

if TYPE_CHECKING:
pass


LLAMA_NEMORETRIEVER_CITATION = """@misc{xu2025llamanemoretrievercolembedtopperforming,
title={Llama Nemoretriever Colembed: Top-Performing Text-Image Retrieval Model},
author={Mengyao Xu and Gabriel Moreira and Ronay Ak and Radek Osmulski and Yauhen Babakhin and Zhiding Yu and Benedikt Schifferer and Even Oldridge},
Expand All @@ -34,6 +32,14 @@ def __init__(
attn_implementation="flash_attention_2",
**kwargs,
):
required_transformers_version = "4.49.0"

if Version(transformers_version) != Version(required_transformers_version):
raise RuntimeError(
f"transformers version {transformers_version} is not match with required "
f"install version {required_transformers_version} to run `nvidia/llama-nemoretriever-colembed`"
)

from transformers import AutoModel

self.model = AutoModel.from_pretrained(
Expand Down
63 changes: 58 additions & 5 deletions mteb/models/model_implementations/nvidia_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@
from transformers import AutoModel, AutoTokenizer
from transformers import __version__ as transformers_version

from mteb import TaskMetadata
from mteb._requires_package import requires_package
from mteb.abstasks.task_metadata import TaskMetadata
from mteb.models import CrossEncoderWrapper
from mteb.models.abs_encoder import AbsEncoder
from mteb.models.instruct_wrapper import InstructSentenceTransformerModel
from mteb.models.model_meta import ModelMeta, ScoringFunction
Expand All @@ -20,23 +21,23 @@
logger = logging.getLogger(__name__)

NV_RETRIEVER_CITATION = """@misc{lee2025nvembedimprovedtechniquestraining,
title={NV-Embed: Improved Techniques for Training LLMs as Generalist Embedding Models},
title={NV-Embed: Improved Techniques for Training LLMs as Generalist Embedding Models},
author={Chankyu Lee and Rajarshi Roy and Mengyao Xu and Jonathan Raiman and Mohammad Shoeybi and Bryan Catanzaro and Wei Ping},
year={2025},
eprint={2405.17428},
archivePrefix={arXiv},
primaryClass={cs.CL},
url={https://arxiv.org/abs/2405.17428},
url={https://arxiv.org/abs/2405.17428},
}"""

LlamaEmbedNemotron_CITATION = """@misc{babakhin2025llamaembednemotron8buniversaltextembedding,
title={Llama-Embed-Nemotron-8B: A Universal Text Embedding Model for Multilingual and Cross-Lingual Tasks},
title={Llama-Embed-Nemotron-8B: A Universal Text Embedding Model for Multilingual and Cross-Lingual Tasks},
author={Yauhen Babakhin and Radek Osmulski and Ronay Ak and Gabriel Moreira and Mengyao Xu and Benedikt Schifferer and Bo Liu and Even Oldridge},
year={2025},
eprint={2511.07025},
archivePrefix={arXiv},
primaryClass={cs.CL},
url={https://arxiv.org/abs/2511.07025},
url={https://arxiv.org/abs/2511.07025},
}"""


Expand Down Expand Up @@ -629,3 +630,55 @@ def _extract_embeddings(
contacts=["ybabakhin"],
citation=LlamaEmbedNemotron_CITATION,
)


def _nemotron_rerank_model(model: str, revision: str, **kwargs) -> CrossEncoderWrapper:
    """Loader for the `nvidia/llama-nemotron-rerank-1b-v2` cross-encoder.

    The model's remote code is only validated against one specific
    transformers release, so we fail fast with an actionable message rather
    than letting it break with an opaque error deep inside model loading.

    Args:
        model: Hugging Face model name or local path.
        revision: Model revision (commit hash) to load.
        **kwargs: Forwarded to `CrossEncoderWrapper` (e.g. `query_prefix`,
            `passage_prefix`, `trust_remote_code`, `model_kwargs`).

    Returns:
        A `CrossEncoderWrapper` wrapping the loaded cross-encoder model.

    Raises:
        RuntimeError: If the installed transformers version differs from the
            pinned version required by this model.
    """
    required_transformers_version = "4.47.1"

    # Exact pin (==, not >=): the model's custom Hub code is known to work
    # only with this transformers release.
    if Version(transformers_version) != Version(required_transformers_version):
        raise RuntimeError(
            f"transformers version {transformers_version} does not match the required "
            f"version {required_transformers_version} needed to run `nvidia/llama-nemotron-rerank-1b-v2`"
        )

    return CrossEncoderWrapper(
        model=model,
        revision=revision,
        **kwargs,
    )


# Metadata registration for NVIDIA's Llama-Nemotron reranker (1B, v2), a
# text-only cross-encoder loaded via `_nemotron_rerank_model`.
nemotron_rerank_1b_v2 = ModelMeta(
    loader=_nemotron_rerank_model,
    loader_kwargs=dict(
        trust_remote_code=True,  # model ships custom modeling code on the Hub
        query_prefix="question:",
        passage_prefix=" \n \n passage:",
        model_kwargs={"torch_dtype": torch.float32},
    ),
    name="nvidia/llama-nemotron-rerank-1b-v2",
    revision="78efcfdc23b53a753f6c73f2d78b18132a34ac4d",
    release_date="2025-10-16",
    languages=["eng-Latn"],
    n_parameters=1235816448,
    memory_usage_mb=2357.0,
    max_tokens=4096,
    embed_dim=2048,
    license="https://www.nvidia.com/en-us/agreements/enterprise-software/nvidia-open-model-license/",
    open_weights=True,
    public_training_code=None,
    public_training_data=None,
    framework=["PyTorch", "Sentence Transformers"],
    reference="https://huggingface.co/nvidia/llama-nemotron-rerank-1b-v2",
    similarity_fn_name=ScoringFunction.COSINE,
    use_instructions=None,
    # Training data is not publicly disclosed, hence the empty set.
    training_datasets=set(
        # private
    ),
    adapted_from="meta-llama/Llama-3.2-1B",
    superseded_by=None,
    modalities=["text"],
    model_type=["cross-encoder"],
    citation=None,
    contacts=None,
)
2 changes: 1 addition & 1 deletion mteb/models/model_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,7 @@ def _from_hub(
revision = revisions[0].commit_id if revisions else None

release_date = cls.fetch_release_date(model_name)
model_license = card_data.license
model_license = card_data.license if card_data.license != "other" else None
n_parameters = cls._calculate_num_parameters_from_hub(model_name)
memory_usage_mb = cls._calculate_memory_usage_mb(model_name, n_parameters)
if model_config and hasattr(model_config, "hidden_size"):
Expand Down
19 changes: 16 additions & 3 deletions mteb/models/sentence_transformer_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,13 +266,24 @@ def encode(


class CrossEncoderWrapper:
"""Wrapper for CrossEncoder models."""
"""Wrapper for CrossEncoder models.

Args:
model: The CrossEncoder model to use. Can be a string (model name) or a CrossEncoder model.
revision: The revision of the model to use.
device: The device used to load the model.
query_prefix: A prefix to add to all queries.
passage_prefix: A prefix to add to all passages.
**kwargs: Additional arguments to pass to the CrossEncoder model.
"""

def __init__(
self,
model: CrossEncoder | str,
revision: str | None = None,
device: str | None = None,
query_prefix: str = "",
passage_prefix: str = "",
**kwargs,
) -> None:
from sentence_transformers import CrossEncoder
Expand All @@ -283,6 +294,8 @@ def __init__(
self.model = CrossEncoder(model, revision=revision, device=device, **kwargs)

self.mteb_model_meta = ModelMeta.from_cross_encoder(self.model)
self.query_prefix = query_prefix
self.passage_prefix = passage_prefix

def predict(
self,
Expand Down Expand Up @@ -311,10 +324,10 @@ def predict(
The predicted relevance scores for each inputs pair.
"""
all_queries_with_instructions = [
text for batch in inputs1 for text in batch["text"]
self.query_prefix + text for batch in inputs1 for text in batch["text"]
]
all_corpus_with_instructions = [
text for batch in inputs2 for text in batch["text"]
self.passage_prefix + text for batch in inputs2 for text in batch["text"]
]

return self.model.predict(
Expand Down