diff --git a/mteb/models/__init__.py b/mteb/models/__init__.py
index c68e4f5a5a..1277235373 100644
--- a/mteb/models/__init__.py
+++ b/mteb/models/__init__.py
@@ -20,6 +20,8 @@
     mxbai_models,
     nomic_models,
     openai_models,
+    promptriever_models,
+    repllama_models,
     ru_sentence_models,
     salesforce_models,
     sentence_transformers_models,
@@ -133,6 +135,8 @@ def model_meta_from_sentence_transformers(model: SentenceTransformer) -> ModelMe
     mxbai_models,
     nomic_models,
     openai_models,
+    promptriever_models,
+    repllama_models,
     ru_sentence_models,
     salesforce_models,
     sentence_transformers_models,
diff --git a/mteb/models/promptriever_models.py b/mteb/models/promptriever_models.py
new file mode 100644
index 0000000000..97a6d2ed6c
--- /dev/null
+++ b/mteb/models/promptriever_models.py
@@ -0,0 +1,117 @@
+from __future__ import annotations
+
+import logging
+from typing import Any, Callable, Literal
+
+import numpy as np
+import torch
+
+from mteb.encoder_interface import Encoder
+from mteb.model_meta import ModelMeta
+
+from .repllama_models import RepLLaMAWrapper
+
+logging.basicConfig(level=logging.WARNING)
+logger = logging.getLogger(__name__)
+
+EncodeTypes = Literal["query", "passage"]
+
+
+class PromptrieverWrapper(RepLLaMAWrapper):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    def encode_queries(self, queries: list[str], **kwargs: Any) -> np.ndarray:
+        queries = [f"query: {query}" for query in queries]
+        if "instruction" in kwargs:
+            # add a "?" to queries missing end punctuation, then append the instruction
+            end_punct_list = [
+                "?" if query.strip()[-1] not in ["?", ".", "!"] else ""
+                for query in queries
+            ]
+            queries = [
+                f"{query}{end_punct_list[i]} {kwargs['instruction']}"
+                for i, query in enumerate(queries)
+            ]
+        return self.encode(queries, **kwargs)
+
+
+def _loader(wrapper: type[PromptrieverWrapper], **kwargs) -> Callable[..., Encoder]:
+    _kwargs = kwargs
+
+    def loader_inner(**kwargs: Any) -> Encoder:
+        return wrapper(**_kwargs, **kwargs)
+
+    return loader_inner
+
+
+promptriever_llama2 = ModelMeta(
+    loader=_loader(
+        PromptrieverWrapper,
+        base_model_name_or_path="meta-llama/Llama-2-7b-hf",
+        peft_model_name_or_path="samaya-ai/promptriever-llama2-7b-v1",
+        device_map="auto",
+        torch_dtype=torch.bfloat16,
+    ),
+    name="samaya-ai/promptriever-llama2-7b-v1",
+    languages=["eng_Latn"],
+    open_source=True,
+    revision="01c7f73d771dfac7d292323805ebc428287df4f9-30b14e3813c0fa45facfd01a594580c3fe5ecf23",  # base-peft revision
+    release_date="2024-09-15",
+)
+
+promptriever_llama3 = ModelMeta(
+    loader=_loader(
+        PromptrieverWrapper,
+        base_model_name_or_path="meta-llama/Meta-Llama-3.1-8B",
+        peft_model_name_or_path="samaya-ai/promptriever-llama3.1-8b-v1",
+        device_map="auto",
+        torch_dtype=torch.bfloat16,
+    ),
+    name="samaya-ai/promptriever-llama3.1-8b-v1",
+    languages=["eng_Latn"],
+    open_source=True,
+    revision="48d6d0fc4e02fb1269b36940650a1b7233035cbb-2ead22cfb1b0e0c519c371c63c2ab90ffc511b8a",  # base-peft revision
+    release_date="2024-09-15",
+)
+
+
+promptriever_llama3_instruct = ModelMeta(
+    loader=_loader(
+        PromptrieverWrapper,
+        base_model_name_or_path="meta-llama/Meta-Llama-3.1-8B-Instruct",
+        peft_model_name_or_path="samaya-ai/promptriever-llama3.1-8b-instruct-v1",
+        device_map="auto",
+        torch_dtype=torch.bfloat16,
+    ),
+    name="samaya-ai/promptriever-llama3.1-8b-instruct-v1",
+    languages=["eng_Latn"],
+    open_source=True,
+    revision="5206a32e0bd3067aef1ce90f5528ade7d866253f-8b677258615625122c2eb7329292b8c402612c21",  # base-peft revision
+    release_date="2024-09-15",
+)
+
+promptriever_mistral_v1 = ModelMeta(
+    loader=_loader(
+        PromptrieverWrapper,
+        base_model_name_or_path="mistralai/Mistral-7B-v0.1",
+        peft_model_name_or_path="samaya-ai/promptriever-mistral-v0.1-7b-v1",
+        device_map="auto",
+        torch_dtype=torch.bfloat16,
+    ),
+    name="samaya-ai/promptriever-mistral-v0.1-7b-v1",
+    languages=["eng_Latn"],
+    open_source=True,
+    revision="7231864981174d9bee8c7687c24c8344414eae6b-876d63e49b6115ecb6839893a56298fadee7e8f5",  # base-peft revision
+    release_date="2024-09-15",
+)
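+
+
+## Debug/usage sketch (hedged): mirrors the "## Debug code" block at the end of
+## repllama_models.py. The task below is illustrative; running this needs GPU
+## memory for a 7B model plus access to the gated meta-llama base weights.
+# import mteb
+# model = mteb.get_model("samaya-ai/promptriever-llama2-7b-v1")
+# tasks = mteb.get_tasks(tasks=["SciFact"], languages=["eng"])
+# evaluation = mteb.MTEB(tasks=tasks)
+# evaluation.run(model)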
diff --git a/mteb/models/repllama_models.py b/mteb/models/repllama_models.py
new file mode 100644
index 0000000000..b5eebb86e5
--- /dev/null
+++ b/mteb/models/repllama_models.py
@@ -0,0 +1,174 @@
+from __future__ import annotations
+
+import logging
+from typing import Any, Callable, Literal
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+import tqdm
+from transformers import AutoModel, AutoTokenizer
+
+from mteb.encoder_interface import Encoder
+from mteb.model_meta import ModelMeta
+from mteb.models.text_formatting_utils import corpus_to_texts
+
+logging.basicConfig(level=logging.WARNING)
+logger = logging.getLogger(__name__)
+
+EncodeTypes = Literal["query", "passage"]
+
+
+class RepLLaMAWrapper:
+    def __init__(self, *args, **kwargs):
+        try:
+            from peft import PeftModel
+        except ImportError:
+            raise ImportError(
+                "To use the RepLLaMA based models `peft` is required. Please install it with `pip install 'mteb[peft]'`."
+            )
+
+        self.base_model = AutoModel.from_pretrained(
+            kwargs["base_model_name_or_path"],
+            torch_dtype=kwargs["torch_dtype"],
+            device_map=kwargs["device_map"],
+        )
+        self.model = PeftModel.from_pretrained(
+            self.base_model, kwargs["peft_model_name_or_path"]
+        )
+        self.model = self.model.merge_and_unload()
+
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            kwargs["base_model_name_or_path"]
+        )
+        self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
+        self.tokenizer.pad_token = self.tokenizer.eos_token
+        self.tokenizer.padding_side = "right"
+        # set the max_length for the evals as they did, although the model can handle longer
+        self.model.config.max_length = 512
+        self.tokenizer.model_max_length = 512
+
+    def create_batch_dict(self, tokenizer, input_texts):
+        max_length = self.model.config.max_length
+        batch_dict = tokenizer(
+            input_texts,
+            max_length=max_length - 1,
+            return_token_type_ids=False,
+            return_attention_mask=False,
+            padding=False,
+            truncation=True,
+        )
+        batch_dict["input_ids"] = [
+            input_ids + [tokenizer.eos_token_id]
+            for input_ids in batch_dict["input_ids"]
+        ]
+        return tokenizer.pad(
+            batch_dict,
+            padding=True,
+            pad_to_multiple_of=8,
+            return_attention_mask=True,
+            return_tensors="pt",
+        )
+
+    def encode(
+        self,
+        sentences: list[str],
+        *,
+        prompt_name: str | None = None,
+        **kwargs: Any,  # noqa
+    ) -> np.ndarray:
+        batch_size = 16 if "batch_size" not in kwargs else kwargs.pop("batch_size")
+        all_embeddings = []
+        for i in tqdm.tqdm(range(0, len(sentences), batch_size)):
+            batch_texts = sentences[i : i + batch_size]
+
+            batch_dict = self.create_batch_dict(self.tokenizer, batch_texts)
+            batch_dict = {
+                key: value.to(self.model.device) for key, value in batch_dict.items()
+            }
+
+            with torch.cuda.amp.autocast():
+                with torch.no_grad():
+                    outputs = self.model(**batch_dict)
+                    last_hidden_state = outputs.last_hidden_state
+                    sequence_lengths = batch_dict["attention_mask"].sum(dim=1) - 1
+                    batch_size = last_hidden_state.shape[0]
+                    # take the hidden state of the final (EOS) token as the embedding
+                    reps = last_hidden_state[
+                        torch.arange(batch_size, device=last_hidden_state.device),
+                        sequence_lengths,
+                    ]
+                    embeddings = F.normalize(reps, p=2, dim=-1)
+                    all_embeddings.append(embeddings.cpu().numpy())
+
+        return np.concatenate(all_embeddings, axis=0)
+
+    def encode_corpus(
+        self,
+        corpus: list[dict[str, str]] | dict[str, list[str]] | list[str],
+        prompt_name: str | None = None,
+        **kwargs: Any,
+    ) -> np.ndarray:
+        sentences = corpus_to_texts(corpus, sep=" ")
+        if "request_qid" in kwargs:
+            kwargs.pop("request_qid")
+        # NOTE: two spaces after the colon
+        sentences = [f"passage:  {sentence}".strip() for sentence in sentences]
+        print(f"Encoding corpus of length {len(sentences)}")
+        print(f"First sentence: {sentences[0]}")
+        return self.encode(sentences, **kwargs)
+
+    def encode_queries(self, queries: list[str], **kwargs: Any) -> np.ndarray:
+        # NOTE: two spaces after the colon
+        queries = [f"query:  {query.strip()}".strip() for query in queries]
+        print(f"Encoding queries of length {len(queries)}")
+        print(queries[0])
+        return self.encode(queries, **kwargs)
+
+
+def _loader(wrapper: type[RepLLaMAWrapper], **kwargs) -> Callable[..., Encoder]:
+    _kwargs = kwargs
+
+    def loader_inner(**kwargs: Any) -> Encoder:
+        return wrapper(**_kwargs, **kwargs)
+
+    return loader_inner
+
+
+repllama_llama2_original = ModelMeta(
+    loader=_loader(
+        RepLLaMAWrapper,
+        base_model_name_or_path="meta-llama/Llama-2-7b-hf",
+        peft_model_name_or_path="castorini/repllama-v1-7b-lora-passage",
+        device_map="auto",
+        torch_dtype=torch.bfloat16,
+    ),
+    name="castorini/repllama-v1-7b-lora-passage",
+    languages=["eng_Latn"],
+    open_source=True,
+    revision="01c7f73d771dfac7d292323805ebc428287df4f9-6097554dfe6e7d93e92f55010b678bcca1e233a8",  # base-peft revision
+    release_date="2023-10-11",
+)
+
+
+repllama_llama2_reproduced = ModelMeta(
+    loader=_loader(
+        RepLLaMAWrapper,
+        base_model_name_or_path="meta-llama/Llama-2-7b-hf",
+        peft_model_name_or_path="samaya-ai/RepLLaMA-reproduced",
+        device_map="auto",
+        torch_dtype=torch.bfloat16,
+    ),
+    name="samaya-ai/RepLLaMA-reproduced",
+    languages=["eng_Latn"],
+    open_source=True,
+    revision="01c7f73d771dfac7d292323805ebc428287df4f9-ad5c1d0938a1e02954bcafb4d811ba2f34052e71",  # base-peft revision
+    release_date="2024-09-15",
+)
+
+
+## Debug code
+# import mteb
+# model = mteb.get_model("samaya-ai/RepLLaMA-reproduced")
+# tasks = mteb.get_tasks(tasks=["SciFact"], languages=["eng"])
+# evaluation = mteb.MTEB(tasks=tasks)
+# evaluation.run(model)
diff --git a/pyproject.toml b/pyproject.toml
index c2b92ed8c6..3ab0439e05 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -56,6 +56,8 @@ dev = ["ruff==0.6.4", # locked so we don't get PRs which fail only due to a lint
     "pytest", "pytest-xdist", "pytest-coverage"]
 codecarbon = ["codecarbon"]
 speedtask = ["GPUtil>=1.4.0", "psutil>=5.9.8"]
+peft = ["peft>=0.11.0"]
+
 
 
 [tool.coverage.report]