-
Notifications
You must be signed in to change notification settings - Fork 583
model: added Querit/Querit #3996
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Samoed
merged 17 commits into
embeddings-benchmark:main
from
youngbeauty250:Querit_reranker
Jan 29, 2026
Merged
Changes from all commits
Commits
Show all changes
17 commits
Select commit
Hold shift + click to select a range
a1265bb
querit_models_add
youngbeauty250 381f556
Querit_Models_Change
youngbeauty250 d0cc412
Update
youngbeauty250 a5877c4
format revise
youngbeauty250 fac844b
add future
Samoed 62ca8d3
format revise
youngbeauty250 ac2dba9
format revise
youngbeauty250 6a5cda8
last format revison
youngbeauty250 82df179
last last revise
youngbeauty250 4faa4a0
last last last revison
youngbeauty250 7abc64c
revise
youngbeauty250 0036ef4
revise
youngbeauty250 359d4eb
change the instruction
youngbeauty250 ce59996
last revison
youngbeauty250 029b4e6
revise
youngbeauty250 4196786
revise
youngbeauty250 ef5aea2
revise
youngbeauty250 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Some comments aren't visible on the classic Files Changed page.
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,245 @@ | ||
| from __future__ import annotations | ||
|
|
||
| import logging | ||
| from typing import TYPE_CHECKING, Any | ||
|
|
||
| import torch | ||
| from tqdm.auto import tqdm | ||
|
|
||
| from mteb.models.model_meta import ModelMeta | ||
|
|
||
| from .rerankers_custom import RerankerWrapper | ||
|
|
||
| if TYPE_CHECKING: | ||
| from torch.utils.data import DataLoader | ||
|
|
||
| from mteb.abstasks.task_metadata import TaskMetadata | ||
| from mteb.types import BatchedInput, PromptType | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
|
|
||
| class QueritWrapper(RerankerWrapper): | ||
| """ | ||
| Multi-GPU / multi-process reranker wrapper for mteb.mteb evaluation. | ||
| Supports flattening all query-passage pairs without explicit grouping. | ||
| """ | ||
|
|
||
| def __init__( | ||
| self, | ||
| model_name: str, | ||
| **kwargs: Any, | ||
| ) -> None: | ||
| super().__init__(model_name, **kwargs) | ||
| from transformers import AutoModel, AutoTokenizer | ||
|
|
||
| if not self.device: | ||
| self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | ||
| model_args = {} | ||
| if self.fp_options: | ||
| model_args["torch_dtype"] = self.fp_options | ||
| self.model = AutoModel.from_pretrained( | ||
| model_name, trust_remote_code=True, **model_args | ||
| ) | ||
| logger.info(f"Using model {model_name}") | ||
|
|
||
| self.model.to(self.device) | ||
| self.tokenizer = AutoTokenizer.from_pretrained( | ||
| model_name, trust_remote_code=True | ||
| ) | ||
| if "[CLS]" not in self.tokenizer.get_vocab(): | ||
| raise ValueError("Tokenizer missing required special token '[CLS]'") | ||
| self.cls_token_id = self.tokenizer.convert_tokens_to_ids("[CLS]") | ||
| self.pad_token_id = self.tokenizer.pad_token_id or 0 | ||
|
|
||
| self.max_length = ( | ||
| min(kwargs.get("max_length", 4096), self.tokenizer.model_max_length) - 1 | ||
| ) # sometimes it's a v large number/max int | ||
| logger.info(f"Using max_length of {self.max_length}, 1 token for [CLS]") | ||
| self.model.eval() | ||
|
|
||
| def process_inputs( | ||
| self, | ||
| pairs: list[str], | ||
| ) -> dict[str, torch.Tensor]: | ||
| """ | ||
| Encode a batch of (query, document) pairs: | ||
| - Concatenate prompt + Query + Content | ||
| - Append [CLS] at the end | ||
| - Left-pad to max_length | ||
| - Generate custom attention mask based on block types | ||
| """ | ||
| # Construct input texts | ||
| enc = self.tokenizer( | ||
| pairs, | ||
| add_special_tokens=False, | ||
| truncation=True, | ||
| max_length=self.max_length, | ||
| padding=False, | ||
| ) | ||
|
|
||
| input_ids_list: list[list[int]] = [] | ||
| attn_mask_list: list[torch.Tensor] = [] | ||
|
|
||
| for ids in enc["input_ids"]: | ||
| # Append [CLS] token | ||
| ids = ids + [self.cls_token_id] | ||
| block_types = [1] * (len(ids) - 1) + [2] # content + CLS | ||
|
|
||
| # Pad or truncate | ||
| if len(ids) < self.max_length: | ||
| pad_len = self.max_length - len(ids) | ||
| ids = [self.pad_token_id] * pad_len + ids | ||
| block_types = [0] * pad_len + block_types | ||
| else: | ||
| ids = ids[-self.max_length :] | ||
| block_types = block_types[-self.max_length :] | ||
|
|
||
| attn = self.compute_mask_content_cls(block_types) | ||
| input_ids_list.append(ids) | ||
| attn_mask_list.append(attn) | ||
|
|
||
| input_ids = torch.tensor(input_ids_list, dtype=torch.long, device=self.device) | ||
| attention_mask = torch.stack(attn_mask_list, dim=0).to(self.device) | ||
|
|
||
| return {"input_ids": input_ids, "attention_mask": attention_mask} | ||
|
|
||
| @torch.inference_mode() | ||
| def predict( | ||
| self, | ||
| inputs1: DataLoader[BatchedInput], | ||
| inputs2: DataLoader[BatchedInput], | ||
| *, | ||
| task_metadata: TaskMetadata, | ||
| hf_split: str, | ||
| hf_subset: str, | ||
| prompt_type: PromptType | None = None, | ||
| **kwargs: Any, | ||
| ) -> list[float]: | ||
| """ | ||
| Predict relevance scores for query-passage pairs. | ||
| Supports both single-process and multi-process/multi-GPU modes. | ||
| """ | ||
| # Flatten all pairs from mteb.mteb DataLoaders | ||
| queries = [text for batch in inputs1 for text in batch["text"]] | ||
| passages = [text for batch in inputs2 for text in batch["text"]] | ||
|
|
||
| instructions = None | ||
| if "instruction" in inputs2.dataset.features: | ||
| instructions = [text for batch in inputs1 for text in batch["instruction"]] | ||
|
|
||
| num_pairs = len(queries) | ||
| if num_pairs == 0: | ||
| return [] | ||
| final_scores: list[float] = [] | ||
|
|
||
| batch_size = kwargs.get("batch_size", self.batch_size) | ||
| with tqdm(total=num_pairs, desc="Scoring", ncols=100) as pbar: | ||
| for start in range(0, num_pairs, batch_size): | ||
| end = min(start + batch_size, num_pairs) | ||
| batch_q = queries[start:end] | ||
| batch_d = passages[start:end] | ||
|
|
||
| batch_instructions = ( | ||
| instructions[start:end] | ||
| if instructions is not None | ||
| else [None] * len(batch_q) | ||
| ) | ||
| pairs = [ | ||
| self.format_instruction(instr, query, doc) | ||
| for instr, query, doc in zip(batch_instructions, batch_q, batch_d) | ||
| ] | ||
| enc = self.process_inputs(pairs) | ||
| out = self.model(**enc) | ||
| scores = out["score"].squeeze(-1).detach().float().cpu().tolist() | ||
|
|
||
| if not isinstance(scores, list): | ||
| scores = [scores] | ||
|
|
||
| final_scores.extend(scores) | ||
| pbar.update(len(scores)) | ||
|
|
||
| return final_scores | ||
|
|
||
| @staticmethod | ||
| def format_instruction(instruction: str | None, query: str, doc: str) -> str: | ||
| if instruction is None: | ||
| output = f"Judge whether the Content meets the requirements based on the Query. Query: {query}; Content: {doc}" | ||
Samoed marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| else: | ||
| output = f"{instruction} Query: {query}; Content: {doc}" | ||
| return output | ||
|
|
||
| @staticmethod | ||
| def compute_mask_content_cls(block_types: list[int]) -> torch.Tensor: | ||
| """ | ||
| Create custom attention mask based on token block types: | ||
| - 0: padding → ignored | ||
| - 1: content → causal attention to previous content only | ||
| - 2: [CLS] → causal attention to all non-padding tokens | ||
|
|
||
| Args: | ||
| block_types: List of token types for one sequence | ||
|
|
||
| Returns: | ||
| [1, seq_len, seq_len] boolean attention mask (True = allowed to attend) | ||
| """ | ||
| pos = torch.tensor(block_types, dtype=torch.long) | ||
| n = pos.shape[0] | ||
| if n == 0: | ||
| return torch.empty((0, 0), dtype=torch.bool, device=pos.device) | ||
|
|
||
| row_types = pos.view(n, 1) | ||
| col_types = pos.view(1, n) | ||
|
|
||
| row_idx = torch.arange(n, device=pos.device).view(n, 1) | ||
| col_idx = torch.arange(n, device=pos.device).view(1, n) | ||
| causal_mask = col_idx <= row_idx | ||
|
|
||
| # Content tokens only attend to previous content | ||
| mask_content = (row_types == 1) & (col_types == 1) & causal_mask | ||
|
|
||
| # [CLS] attends to all non-pad tokens (causal) | ||
| mask_cls = (row_types == 2) & (col_types != 0) & causal_mask | ||
|
|
||
| type_mask = mask_content | mask_cls | ||
| return type_mask.unsqueeze(0) | ||
|
|
||
|
|
||
| querit_reranker_training_data = { | ||
| "MIRACLRanking", # https://huggingface.co/datasets/mteb/MIRACLReranking | ||
| "MrTidyRetrieval", # https://huggingface.co/datasets/mteb/mrtidy | ||
| "ruri-v3-dataset-reranker", # https://huggingface.co/datasets/cl-nagoya/ruri-v3-dataset-reranker | ||
| "MultiLongDocReranking", # https://huggingface.co/datasets/Shitao/MLDR | ||
| "MindSmallReranking", # https://huggingface.co/datasets/mteb/MindSmallReranking | ||
| "MSMARCO", # https://huggingface.co/datasets/mteb/msmarco | ||
| "CQADupStack", # https://huggingface.co/datasets/mteb/cqadupstack-* | ||
| "AskUbuntuDupQuestions", # https://github.com/taolei87/askubuntu & The corpus and queries that overlap with mteb/askubuntudupquestions-reranking have been removed. | ||
| "T2Reranking", # https://huggingface.co/datasets/THUIR/T2Ranking & The corpus and queries that overlap with mteb/T2Reranking have been removed. | ||
| } | ||
|
|
||
| model_meta = ModelMeta( | ||
| loader=QueritWrapper, | ||
| loader_kwargs={ | ||
| "fp_options": "bfloat16", | ||
| }, | ||
| name="Querit/Querit", | ||
youngbeauty250 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| model_type=["cross-encoder"], | ||
| languages=["eng-Latn"], | ||
| open_weights=True, | ||
| revision="5ad2649cc4defb7e1361262260e9a781f14b08bc", | ||
| release_date="2026-01-24", | ||
| n_parameters=4919636992, | ||
| n_embedding_parameters=131907584, | ||
| embed_dim=1024, | ||
| memory_usage_mb=9383.0, | ||
| max_tokens=4096, | ||
| reference="https://huggingface.co/Querit/Querit", | ||
| similarity_fn_name=None, | ||
| training_datasets=querit_reranker_training_data, | ||
| license="apache-2.0", | ||
| framework=["PyTorch"], | ||
| use_instructions=None, | ||
| public_training_code=None, | ||
| public_training_data=None, | ||
| citation=None, | ||
Samoed marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| ) | ||
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.