185 changes: 185 additions & 0 deletions mteb/models/bmretriever_models.py
@@ -0,0 +1,185 @@
from __future__ import annotations

from functools import partial
from typing import Any, Callable

import torch
from sentence_transformers import SentenceTransformer
from sentence_transformers.models import Pooling, Transformer

from mteb.encoder_interface import PromptType
from mteb.model_meta import ModelMeta
from mteb.models.instruct_wrapper import InstructSentenceTransformerWrapper


def instruction_template(
    instruction: str, prompt_type: PromptType | None = None
) -> str:
    return (
        f"{instruction}\nQuery: "
        if (prompt_type is None or prompt_type == PromptType.query) and instruction
        else "Represent this passage\npassage: "
    )


class BMRetrieverWrapper(InstructSentenceTransformerWrapper):
    def __init__(
        self,
        model_name: str,
        instruction_template: Callable[[str, PromptType | None], str] | None = None,
        max_seq_length: int | None = None,
        apply_instruction_to_passages: bool = True,
        padding_side: str | None = None,
        add_eos_token: bool = False,
        prompts_dict: dict[str, str] | None = None,
        **kwargs: Any,
    ):
        # Deliberately skip the parent __init__: BMRetriever embeddings use
        # last-token pooling, so the SentenceTransformer is assembled manually
        # below instead of being loaded from the hub's module configuration.
        self.model_name = model_name
        self.instruction_template = instruction_template
        self.apply_instruction_to_passages = apply_instruction_to_passages
        self.add_eos_token = add_eos_token
        self.prompts_dict = prompts_dict

        transformer = Transformer(
            model_name,
            max_seq_length=max_seq_length,
            **kwargs,
        )
        pooling = Pooling(
            transformer.get_word_embedding_dimension(), pooling_mode="lasttoken"
        )
        self.model = SentenceTransformer(modules=[transformer, pooling])

        if max_seq_length is not None:
            self.model.max_seq_length = max_seq_length

        if padding_side is not None:
            self.model.tokenizer.padding_side = padding_side


# MTEB tasks that overlap with BMRetriever's training data; see
# https://huggingface.co/datasets/BMRetriever/biomed_retrieval_dataset
BMRETRIEVER_TRAINING_DATA = {
    "FEVER": ["train"],
    "MSMARCO": ["train"],
    "NQ": ["train"],
}

BMRetriever_410M = ModelMeta(
    loader=partial(
        BMRetrieverWrapper,
        model_name="BMRetriever/BMRetriever-410M",
        config_args={"revision": "e3569bfbcfe3a1bc48c142e11a7b0f38e86065a3"},
        model_args={"torch_dtype": torch.float32},
        instruction_template=instruction_template,
        padding_side="left",
        add_eos_token=True,
        apply_instruction_to_passages=True,
    ),
    name="BMRetriever/BMRetriever-410M",
    languages=["eng-Latn"],
    open_weights=True,
    revision="e3569bfbcfe3a1bc48c142e11a7b0f38e86065a3",
    release_date="2024-04-29",
    embed_dim=1024,
    n_parameters=353_822_720,
    memory_usage_mb=1349,
    max_tokens=2048,
    license="mit",
    reference="https://huggingface.co/BMRetriever/BMRetriever-410M",
    similarity_fn_name="cosine",
    framework=["Sentence Transformers", "PyTorch"],
    use_instructions=True,
    public_training_code=None,
    public_training_data=None,
    training_datasets=BMRETRIEVER_TRAINING_DATA,
)

BMRetriever_1B = ModelMeta(
    loader=partial(
        BMRetrieverWrapper,
        model_name="BMRetriever/BMRetriever-1B",
        config_args={"revision": "1b758c5f4d3af48ef6035cc4088bdbcd7df43ca6"},
        model_args={"torch_dtype": torch.float32},
        instruction_template=instruction_template,
        padding_side="left",
        add_eos_token=True,
        apply_instruction_to_passages=True,
    ),
    name="BMRetriever/BMRetriever-1B",
    languages=["eng-Latn"],
    open_weights=True,
    revision="1b758c5f4d3af48ef6035cc4088bdbcd7df43ca6",
    release_date="2024-04-29",
    embed_dim=2048,
    n_parameters=908_759_040,
    memory_usage_mb=3466,
    max_tokens=2048,
    license="mit",
    reference="https://huggingface.co/BMRetriever/BMRetriever-1B",
    similarity_fn_name="cosine",
    framework=["Sentence Transformers", "PyTorch"],
    use_instructions=True,
    public_training_code=None,
    public_training_data=None,
    training_datasets=BMRETRIEVER_TRAINING_DATA,
)

BMRetriever_2B = ModelMeta(
    loader=partial(
        BMRetrieverWrapper,
        model_name="BMRetriever/BMRetriever-2B",
        config_args={"revision": "718179afd57926369c347f46eee616db81084941"},
        model_args={"torch_dtype": torch.float32},
        instruction_template=instruction_template,
        padding_side="left",
        add_eos_token=True,
        apply_instruction_to_passages=True,
    ),
    name="BMRetriever/BMRetriever-2B",
    languages=["eng-Latn"],
    open_weights=True,
    revision="718179afd57926369c347f46eee616db81084941",
    release_date="2024-04-29",
    embed_dim=2048,
    n_parameters=2_506_172_416,
    memory_usage_mb=9560,
    max_tokens=8192,
    license="mit",
    reference="https://huggingface.co/BMRetriever/BMRetriever-2B",
    similarity_fn_name="cosine",
    framework=["Sentence Transformers", "PyTorch"],
    use_instructions=True,
    public_training_code=None,
    public_training_data=None,
    training_datasets=BMRETRIEVER_TRAINING_DATA,
)

BMRetriever_7B = ModelMeta(
    loader=partial(
        BMRetrieverWrapper,
        model_name="BMRetriever/BMRetriever-7B",
        config_args={"revision": "e3569bfbcfe3a1bc48c142e11a7b0f38e86065a3"},
        model_args={"torch_dtype": torch.float32},
        instruction_template=instruction_template,
        padding_side="left",
        add_eos_token=True,
        apply_instruction_to_passages=True,
    ),
    name="BMRetriever/BMRetriever-7B",
    languages=["eng-Latn"],
    open_weights=True,
    revision="e3569bfbcfe3a1bc48c142e11a7b0f38e86065a3",
    release_date="2024-04-29",
    embed_dim=4096,
    n_parameters=7_110_660_096,
    memory_usage_mb=27124,
    max_tokens=32768,
    license="mit",
    reference="https://huggingface.co/BMRetriever/BMRetriever-7B",
    similarity_fn_name="cosine",
    framework=["Sentence Transformers", "PyTorch"],
    use_instructions=True,
    public_training_code=None,
    public_training_data=None,
    training_datasets=BMRETRIEVER_TRAINING_DATA,
)
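
For reviewers who want to try the new entries locally, a minimal sketch using the public mteb API (`mteb.get_model`, `mteb.get_tasks`, and `mteb.MTEB` as in the current library; the task choice is only an illustration, not part of this PR):

```python
import mteb

# Resolve the ModelMeta registered above and instantiate the wrapper.
model = mteb.get_model("BMRetriever/BMRetriever-410M")

# NFCorpus is an arbitrary biomedical retrieval task chosen for illustration.
tasks = mteb.get_tasks(tasks=["NFCorpus"])
results = mteb.MTEB(tasks=tasks).run(model)
```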
2 changes: 2 additions & 0 deletions mteb/models/overview.py
@@ -21,6 +21,7 @@
blip2_models,
blip_models,
bm25,
bmretriever_models,
cadet_models,
cde_models,
clip_models,
@@ -121,6 +122,7 @@
blip2_models,
blip_models,
bm25,
bmretriever_models,
cadet_models,
clip_models,
codesage_models,
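
A quick sanity check that the registration in `overview.py` took effect (a sketch assuming `mteb.get_model_meta`, which exists in current mteb):

```python
import mteb

# Should return the ModelMeta defined in bmretriever_models.py.
meta = mteb.get_model_meta("BMRetriever/BMRetriever-410M")
assert meta.n_parameters == 353_822_720
```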