Merged
8 changes: 6 additions & 2 deletions mteb/models/get_model_meta.py
@@ -75,7 +75,10 @@ def get_model_metas(


def get_model(
model_name: str, revision: str | None = None, **kwargs: Any
model_name: str,
revision: str | None = None,
device: str | None = None,
**kwargs: Any,
) -> MTEBModels:
"""A function to fetch and load model object by name.

@@ -85,13 +88,14 @@ def get_model(
Args:
model_name: Name of the model to fetch
revision: Revision of the model to fetch
device: Device used to load the model
**kwargs: Additional keyword arguments to pass to the model loader

Returns:
A model object
"""
meta = get_model_meta(model_name, revision)
model = meta.load_model(**kwargs)
model = meta.load_model(device=device, **kwargs)

if kwargs:
logger.info(
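For context, a minimal usage sketch of the updated entry point; the model name below is only an illustrative choice of a registered model:

    import mteb

    # Passing device=None keeps the previous behaviour (the loader picks a device);
    # an explicit string pins the model to that device at load time.
    model = mteb.get_model(
        "sentence-transformers/all-MiniLM-L6-v2",  # illustrative model name
        device="cpu",
    )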
20 changes: 17 additions & 3 deletions mteb/models/instruct_wrapper.py
@@ -18,6 +18,7 @@ def instruct_wrapper(
model_name_or_path: str,
mode: str,
instruction_template: str | Callable[[str, PromptType | None], str] | None = None,
device: str | None = None,
**kwargs,
):
"""Instruct wrapper for models. Uses GritLM to pass instructions to the model.
@@ -28,6 +29,7 @@
model_name_or_path: Model name or path.
mode: Mode of the model. Either 'query' or 'passage'.
instruction_template: Instruction template. Should contain the string '{instruction}'.
device: Device used to load the model.
**kwargs: Additional arguments to pass to the model.
"""
requires_package(
@@ -40,6 +42,7 @@ def __init__(
self,
model_name_or_path: str,
mode: str,
device: str | None = None,
instruction_template: str
| Callable[[str, PromptType | None], str]
| None = None,
@@ -63,7 +66,12 @@ def __init__(
)

self.instruction_template = instruction_template
super().__init__(model_name_or_path=model_name_or_path, mode=mode, **kwargs)
super().__init__(
model_name_or_path=model_name_or_path,
mode=mode,
device=device,
**kwargs,
)

def encode(
self,
@@ -95,7 +103,9 @@ def encode(
embeddings = embeddings.cpu().detach().float().numpy()
return embeddings

return InstructGritLMModel(model_name_or_path, mode, instruction_template, **kwargs)
return InstructGritLMModel(
model_name_or_path, mode, instruction_template=instruction_template, **kwargs
)


class InstructSentenceTransformerModel(AbsEncoder):
@@ -105,6 +115,7 @@ def __init__(
self,
model_name: str,
revision: str,
device: str | None = None,
instruction_template: str
| Callable[[str, PromptType | None], str]
| None = None,
@@ -122,6 +133,7 @@ def __init__(
Arguments:
model_name: Model name of the sentence transformers model.
revision: Revision of the sentence transformers model.
device: Device used to load the model.
instruction_template: Model template. Should contain the string '{instruction}'.
max_seq_length: Maximum sequence length. If None, the maximum sequence length will be read from the model config.
apply_instruction_to_passages: Whether to apply the instruction template to the passages.
@@ -158,7 +170,9 @@ def __init__(
kwargs.setdefault("tokenizer_kwargs", {}).update(tokenizer_params)

self.model_name = model_name
self.model = SentenceTransformer(model_name, revision=revision, **kwargs)
self.model = SentenceTransformer(
model_name, revision=revision, device=device, **kwargs
)
if max_seq_length:
# https://github.com/huggingface/sentence-transformers/issues/3575
self.model.max_seq_length = max_seq_length
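A rough sketch of how the new argument reaches the instruct wrappers, using the InstructSentenceTransformerModel modified above; the model name, revision, and template are illustrative:

    from mteb.models.instruct_wrapper import InstructSentenceTransformerModel

    # `device` is now forwarded to SentenceTransformer(...), so the underlying
    # model is placed on the requested device when it is constructed.
    encoder = InstructSentenceTransformerModel(
        model_name="intfloat/e5-mistral-7b-instruct",  # illustrative
        revision="main",                               # illustrative
        device="cuda:0",
        instruction_template="Instruct: {instruction}\nQuery: ",
    )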
2 changes: 2 additions & 0 deletions mteb/models/model_implementations/bmretriever_models.py
@@ -25,6 +25,7 @@ def __init__(
self,
model_name: str,
revision: str,
device: str | None = None,
instruction_template: str
| Callable[[str, PromptType | None], str]
| None = None,
@@ -52,6 +53,7 @@ def __init__(

transformer = Transformer(
model_name,
device=device,
**kwargs,
)
pooling = Pooling(
11 changes: 9 additions & 2 deletions mteb/models/model_implementations/cde_models.py
@@ -49,10 +49,17 @@ class CDEWrapper(SentenceTransformerEncoderWrapper):
"InstructionReranking",
)

def __init__(self, model: str, *args, **kwargs: Any) -> None:
def __init__(
self,
model: str,
revision: str | None = None,
device: str | None = None,
*args,
**kwargs: Any,
) -> None:
from transformers import AutoConfig

super().__init__(model, *args, **kwargs)
super().__init__(model, revision=revision, device=device, *args, **kwargs)
model_config = AutoConfig.from_pretrained(model, trust_remote_code=True)
self.max_sentences = model_config.transductive_corpus_size

10 changes: 5 additions & 5 deletions mteb/models/model_implementations/e5_v.py
@@ -30,6 +30,7 @@ def __init__(
self,
model_name: str,
revision: str,
device: str | None = None,
composed_prompt=None,
**kwargs: Any,
):
@@ -47,8 +48,7 @@ def __init__(
self.processor = LlavaNextProcessor.from_pretrained(
model_name, revision=revision
)
if "device" in kwargs:
self.device = kwargs.pop("device")
self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
self.model = LlavaNextForConditionalGeneration.from_pretrained(
model_name, revision=revision, **kwargs
)
@@ -87,7 +87,7 @@ def get_text_embeddings(
],
return_tensors="pt",
padding=True,
).to("cuda")
).to(self.device)
text_outputs = self.model(
**text_inputs, output_hidden_states=True, return_dict=True
).hidden_states[-1][:, -1, :]
@@ -111,7 +111,7 @@
batch["image"],
return_tensors="pt",
padding=True,
).to("cuda")
).to(self.device)
image_outputs = self.model(
**img_inputs, output_hidden_states=True, return_dict=True
).hidden_states[-1][:, -1, :]
@@ -141,7 +141,7 @@
]
inputs = self.processor(
prompts, batch["image"], return_tensors="pt", padding=True
).to("cuda")
).to(self.device)
outputs = self.model(
**inputs, output_hidden_states=True, return_dict=True
).hidden_states[-1][:, -1, :]
11 changes: 6 additions & 5 deletions mteb/models/model_implementations/jina_models.py
@@ -257,6 +257,7 @@ def __init__(
self,
model: CrossEncoder | str,
revision: str | None = None,
device: str | None = None,
trust_remote_code: bool = True,
**kwargs: Any,
) -> None:
@@ -267,10 +268,7 @@ def __init__(
model, trust_remote_code=trust_remote_code, dtype="auto"
)

device = kwargs.get("device", None)
if device is None:
device = get_device_name()
logger.info(f"Use pytorch device: {device}")
device = device or get_device_name()

self.model.to(device)
self.model.eval()
@@ -320,6 +318,7 @@ def __init__(
self,
model: str,
revision: str,
device: str | None = None,
model_prompts: dict[str, str] | None = None,
**kwargs,
) -> None:
@@ -339,7 +338,9 @@
)
import flash_attn # noqa: F401

super().__init__(model, revision, model_prompts, **kwargs)
super().__init__(
model, revision, device=device, model_prompts=model_prompts, **kwargs
)

def encode(
self,
@@ -30,13 +30,13 @@ def __init__(
self,
model_name: str,
revision: str,
device: str | None = None,
model_prompts: dict[str, str] | None = None,
**kwargs: Any,
):
from transformers import AutoModel, AutoTokenizer

self.model_name = model_name
device = kwargs.pop("device", None)
self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
self.model = AutoModel.from_pretrained(
model_name, revision=revision, **kwargs
5 changes: 4 additions & 1 deletion mteb/models/model_implementations/nomic_models.py
@@ -23,6 +23,7 @@ def __init__(
self,
model_name: str,
revision: str,
device: str | None = None,
model_prompts: dict[str, str] | None = None,
**kwargs: Any,
):
@@ -37,7 +38,9 @@
f"Current transformers version is {transformers.__version__} is lower than the required version"
f" {MODERN_BERT_TRANSFORMERS_MIN_VERSION}"
)
super().__init__(model_name, revision, model_prompts, **kwargs)
super().__init__(
model_name, revision, device=device, model_prompts=model_prompts, **kwargs
)

def to(self, device: torch.device) -> None:
self.model.to(device)
3 changes: 2 additions & 1 deletion mteb/models/model_implementations/nvidia_models.py
@@ -337,6 +337,7 @@ def __init__(
self,
model_name: str,
revision: str,
device: str | None = None,
) -> None:
required_transformers_version = "4.51.0"
if Version(transformers_version) != Version(required_transformers_version):
@@ -355,7 +356,7 @@
self.attn_implementation = (
"flash_attention_2" if torch.cuda.is_available() else "eager"
)
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
self.task_prompts = TASK_PROMPTS
self.instruction_template = self._instruction_template

4 changes: 3 additions & 1 deletion mteb/models/model_meta.py
@@ -250,7 +250,7 @@ def _check_name(cls, v: str | None) -> str | None:
)
return v

def load_model(self, **kwargs: Any) -> MTEBModels:
def load_model(self, device: str | None = None, **kwargs: Any) -> MTEBModels:
"""Loads the model using the specified loader function."""
if self.loader is None:
raise NotImplementedError(
@@ -262,6 +262,8 @@ def load_model(self, **kwargs: Any) -> MTEBModels:
# Allow overwrites
_kwargs = self.loader_kwargs.copy()
_kwargs.update(kwargs)
if device is not None:
_kwargs["device"] = device

model: MTEBModels = self.loader(self.name, revision=self.revision, **_kwargs)
model.mteb_model_meta = self # type: ignore[misc]
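A small sketch of the resulting precedence when loading through ModelMeta: registry loader_kwargs are applied first, call-time kwargs override them, and a non-None device overrides any earlier device value. The model name is illustrative:

    import mteb

    meta = mteb.get_model_meta("sentence-transformers/all-MiniLM-L6-v2")  # illustrative
    # Equivalent to mteb.get_model(..., device="cpu"): the explicit device wins
    # over any device entry in the registered loader_kwargs.
    model = meta.load_model(device="cpu")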
18 changes: 16 additions & 2 deletions mteb/models/models_protocols.py
@@ -83,12 +83,19 @@ class EncoderProtocol(Protocol):
In general the interface is kept aligned with sentence-transformers interface. In cases where exceptions occurs these are handled within MTEB.
"""

def __init__(self, model_name: str, revision: str | None, **kwargs: Any) -> None:
def __init__(
self,
model_name: str,
revision: str | None,
device: str | None = None,
**kwargs: Any,
) -> None:
"""The initialization function for the encoder. Used when calling it from the mteb run CLI.

Args:
model_name: Name of the model
revision: revision of the model
device: Device used to load the model
kwargs: Any additional kwargs
"""
...
@@ -181,12 +188,19 @@ class CrossEncoderProtocol(Protocol):
In general the interface is kept aligned with sentence-transformers interface. In cases where exceptions occurs these are handled within MTEB.
"""

def __init__(self, model_name: str, revision: str | None, **kwargs: Any) -> None:
def __init__(
self,
model_name: str,
revision: str | None,
device: str | None = None,
**kwargs: Any,
) -> None:
"""The initialization function for the encoder. Used when calling it from the mteb run CLI.

Args:
model_name: Name of the model
revision: revision of the model
device: Device used to load the model
kwargs: Any additional kwargs
"""
...
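For illustration only, a minimal object whose constructor matches the updated protocol signature; the class is hypothetical and its encode method is simplified relative to the full EncoderProtocol:

    import numpy as np

    class ToyEncoder:
        """Hypothetical encoder with the constructor shape required by EncoderProtocol."""

        def __init__(self, model_name: str, revision: str | None, device: str | None = None, **kwargs):
            self.model_name = model_name
            self.revision = revision
            self.device = device or "cpu"  # fall back to CPU when no device is given

        def encode(self, sentences, **kwargs) -> np.ndarray:
            # Placeholder embeddings; a real encoder would run its model on self.device.
            return np.zeros((len(sentences), 8), dtype=np.float32)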
14 changes: 10 additions & 4 deletions mteb/models/sentence_transformer_wrapper.py
@@ -26,17 +26,18 @@


def sentence_transformers_loader(
model_name: str, revision: str | None = None, **kwargs
model_name: str, revision: str | None = None, device: str | None = None, **kwargs
) -> SentenceTransformerEncoderWrapper:
"""Loads a SentenceTransformer model and wraps it in a SentenceTransformerEncoderWrapper.

Args:
model_name: The name of the SentenceTransformer model to load.
revision: The revision of the model to load.
device: The device used to load the model.
kwargs: Additional arguments to pass to the SentenceTransformer model.
"""
return SentenceTransformerEncoderWrapper(
model=model_name, revision=revision, **kwargs
model=model_name, revision=revision, device=device, **kwargs
)


@@ -49,6 +50,7 @@ def __init__(
self,
model: str | SentenceTransformer,
revision: str | None = None,
device: str | None = None,
model_prompts: dict[str, str] | None = None,
**kwargs,
) -> None:
@@ -57,6 +59,7 @@ def __init__(
Args:
model: The SentenceTransformer model to use. Can be a string (model name), a SentenceTransformer model, or a CrossEncoder model.
revision: The revision of the model to use.
device: The device used to load the model.
model_prompts: A dictionary mapping task names to prompt names.
First priority is given to the composed prompt of task name + prompt type (query or passage), then to the specific task prompt,
then to the composed prompt of task type + prompt type, then to the specific task type prompt,
@@ -66,7 +69,9 @@
from sentence_transformers import SentenceTransformer

if isinstance(model, str):
self.model = SentenceTransformer(model, revision=revision, **kwargs)
self.model = SentenceTransformer(
model, revision=revision, device=device, **kwargs
)
else:
self.model = model

@@ -266,14 +271,15 @@ def __init__(
self,
model: CrossEncoder | str,
revision: str | None = None,
device: str | None = None,
**kwargs,
) -> None:
from sentence_transformers import CrossEncoder

if isinstance(model, CrossEncoder):
self.model = model
elif isinstance(model, str):
self.model = CrossEncoder(model, revision=revision, **kwargs)
self.model = CrossEncoder(model, revision=revision, device=device, **kwargs)

self.mteb_model_meta = ModelMeta.from_cross_encoder(self.model)

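Finally, a sketch of loader-level usage with the function shown above; the model name is illustrative:

    from mteb.models.sentence_transformer_wrapper import sentence_transformers_loader

    # The wrapper passes `device` straight to SentenceTransformer, so the model
    # is loaded onto the requested device rather than the library default.
    encoder = sentence_transformers_loader(
        "sentence-transformers/all-MiniLM-L6-v2",  # illustrative model name
        revision=None,
        device="cuda",
    )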