mteb/encoder_interface.py (2 changes: 1 addition & 1 deletion)
@@ -14,7 +14,7 @@

class PromptType(str, Enum):
query = "query"
-passage = "passage"
+document = "document"


@runtime_checkable
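For orientation, the sketch below (not part of the diff) shows how the renamed member is passed when encoding queries versus corpus documents. `ToyEncoder`, its prompt strings, the task name, and the example texts are all illustrative assumptions; only `PromptType` and the `task_name`/`prompt_type` keyword arguments come from the interface changed above.

```python
import numpy as np

from mteb.encoder_interface import PromptType


class ToyEncoder:
    """Hypothetical stand-in for an mteb model wrapper (not a real mteb class)."""

    # Prompt strings mirror the e5-style prompts further down in this diff.
    prompts = {
        PromptType.query.value: "query: ",
        PromptType.document.value: "passage: ",  # key renamed, prompt text unchanged
    }

    def encode(self, sentences, *, task_name, prompt_type=None, **kwargs):
        # Prefix each text with the prompt selected by prompt_type, if any.
        prefix = self.prompts.get(prompt_type.value, "") if prompt_type else ""
        texts = [prefix + s for s in sentences]
        # Stand-in for a real forward pass: one random vector per input text.
        return np.random.default_rng(0).random((len(texts), 8))


model = ToyEncoder()
query_emb = model.encode(
    ["how do ants sleep?"], task_name="NFCorpus", prompt_type=PromptType.query
)
doc_emb = model.encode(
    ["Ants rest in short bursts."], task_name="NFCorpus", prompt_type=PromptType.document
)
```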
mteb/evaluation/evaluators/Image/Any2AnyRetrievalEvaluator.py (6 changes: 3 additions & 3 deletions)
@@ -195,7 +195,7 @@ def search(
sub_corpus_embeddings = self.model.get_text_embeddings(
texts=corpus_texts,
task_name=task_name,
-prompt_type=PromptType.passage,
+prompt_type=PromptType.document,
**self.encode_kwargs,
)
else:
@@ -213,7 +213,7 @@ def search(
sub_corpus_embeddings = self.model.get_image_embeddings(
images=corpus_image_dataloader,
task_name=task_name,
-prompt_type=PromptType.passage,
+prompt_type=PromptType.document,
**self.encode_kwargs,
)
elif corpus_modality == "image,text":
@@ -222,7 +222,7 @@
texts=corpus_texts,
images=corpus_image_dataloader,
task_name=task_name,
-prompt_type=PromptType.passage,
+prompt_type=PromptType.document,
**self.encode_kwargs,
)
else:
mteb/evaluation/evaluators/RerankingEvaluator.py (8 changes: 4 additions & 4 deletions)
@@ -180,7 +180,7 @@ def _encode_candidates_batched(
all_docs,
model,
task_name=self.task_name,
-prompt_type=PromptType.passage,
+prompt_type=PromptType.document,
**self.encode_kwargs,
)

@@ -245,7 +245,7 @@ def _encode_candidates_individual(
model.encode(
docs,
task_name=self.task_name,
-prompt_type=PromptType.passage,
+prompt_type=PromptType.document,
**self.encode_kwargs,
)
)
@@ -293,7 +293,7 @@ def _encode_candidates_miracl_batched(self, all_query_embs, model: Encoder):
model.encode(
all_docs,
task_name=self.task_name,
-prompt_type=PromptType.passage,
+prompt_type=PromptType.document,
**self.encode_kwargs,
)
)
@@ -345,7 +345,7 @@ def _encode_candidates_miracl_individual(self, model: Encoder):
model.encode(
docs,
task_name=self.task_name,
-prompt_type=PromptType.passage,
+prompt_type=PromptType.document,
**self.encode_kwargs,
)
)
mteb/evaluation/evaluators/RetrievalEvaluator.py (6 changes: 3 additions & 3 deletions)
@@ -159,7 -159,7 @@ def search(
sub_corpus_embeddings = self.model.encode(
corpus[corpus_start_idx:corpus_end_idx], # type: ignore
task_name=task_name,
-prompt_type=PromptType.passage,
+prompt_type=PromptType.document,
request_qid=request_qid,
**self.encode_kwargs,
)
@@ -385,7 +385,7 @@ def encode_corpus(
corpus: list[dict[str, str]],
task_name: str,
batch_size: int,
-prompt_type: PromptType = PromptType.passage,
+prompt_type: PromptType = PromptType.document,
request_qid: str | None = None,
**kwargs,
):
@@ -416,7 +416,7 @@ def encode(
prompt_type: PromptType | None = None,
**kwargs,
):
-if prompt_type and prompt_type == PromptType.passage:
+if prompt_type and prompt_type == PromptType.document:
return self.encode_corpus(
sentences, task_name, prompt_type=prompt_type, **kwargs
)
mteb/models/cadet_models.py (2 changes: 1 addition & 1 deletion)
@@ -35,7 +35,7 @@
revision="8056d118be37a566f20972a5f35cda815f6bc47e",
model_prompts={
"query": "query: ",
"passage": "passage: ",
"document": "passage: ",
},
),
name="manveertamber/cadet-embed-base-v1",
mteb/models/cohere_models.py (2 changes: 1 addition & 1 deletion)
@@ -211,7 +211,7 @@ def encode(
"MultilabelClassification": "classification",
"Clustering": "clustering",
PromptType.query.value: "search_query",
PromptType.passage.value: "search_document",
PromptType.document.value: "search_document",
}

cohere_mult_3 = ModelMeta(
mteb/models/e5_models.py (2 changes: 1 addition & 1 deletion)
@@ -110,7 +110,7 @@

model_prompts = {
PromptType.query.value: "query: ",
PromptType.passage.value: "passage: ",
PromptType.document.value: "passage: ",
}

E5_TRAINING_DATA = {
mteb/models/gme_v_models.py (4 changes: 2 additions & 2 deletions)
@@ -182,7 +182,7 @@ def encode_corpus(self, corpus: list[dict[str, str]], **kwargs):
else doc["text"].strip()
for doc in corpus
]
-embeddings = self.encode(sentences, prompt_type=PromptType.passage, **kwargs)
+embeddings = self.encode(sentences, prompt_type=PromptType.document, **kwargs)
return embeddings

def get_image_embeddings(self, images: list[Image.Image] | DataLoader, **kwargs):
@@ -210,7 +210,7 @@ def get_fused_embeddings(
instruction=None,
**kwargs: Any,
):
-if prompt_type == PromptType.passage:
+if prompt_type == PromptType.document:
instruction = None
elif instruction is None:
instruction = self.get_instruction(task_name, prompt_type)
mteb/models/google_models.py (2 changes: 1 addition & 1 deletion)
@@ -39,7 +39,7 @@
"Clustering": "CLUSTERING",
"STS": "SEMANTIC_SIMILARITY",
PromptType.query.value: "RETRIEVAL_QUERY",
PromptType.passage.value: "RETRIEVAL_DOCUMENT",
PromptType.document.value: "RETRIEVAL_DOCUMENT",
}

GECKO_TRAINING_DATA = {
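The Cohere and Google wrappers above both map MTEB task types and prompt types onto provider-specific input types. Below is a small, self-contained sketch of that resolution; the dictionary entries are copied from the google_models.py hunk, while the helper function and its fallback order are illustrative assumptions rather than the wrappers' verbatim logic.

```python
from __future__ import annotations

from mteb.encoder_interface import PromptType

# Entries copied from the google_models.py mapping shown above (abbreviated).
task_to_input_type = {
    "Clustering": "CLUSTERING",
    "STS": "SEMANTIC_SIMILARITY",
    PromptType.query.value: "RETRIEVAL_QUERY",
    PromptType.document.value: "RETRIEVAL_DOCUMENT",
}


def resolve_input_type(task_type: str, prompt_type: PromptType | None) -> str | None:
    """Prefer the prompt type (query/document) when present, else fall back to the task type."""
    if prompt_type is not None and prompt_type.value in task_to_input_type:
        return task_to_input_type[prompt_type.value]
    return task_to_input_type.get(task_type)


print(resolve_input_type("Retrieval", PromptType.document))  # RETRIEVAL_DOCUMENT
print(resolve_input_type("Clustering", None))  # CLUSTERING
```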
mteb/models/instruct_wrapper.py (7 changes: 5 additions & 2 deletions)
@@ -147,10 +147,13 @@ def encode(
)

# instructions won't be applied to passages
-if not self.apply_instruction_to_passages and prompt_type == PromptType.passage:
+if (
+not self.apply_instruction_to_passages
+and prompt_type == PromptType.document
+):
instruction = None
logger.info(
f"No instruction used, because prompt type = {prompt_type.passage}"
f"No instruction used, because prompt type = {prompt_type.document}"
)

if instruction:
mteb/models/jasper_models.py (2 changes: 1 addition & 1 deletion)
@@ -44,7 +44,7 @@ def encode(
instruction = self.get_task_instruction(task_name, prompt_type)

# instructions won't be applied to passages
-if prompt_type == PromptType.passage and task.metadata.category == "s2p":
+if prompt_type == PromptType.document and task.metadata.category == "s2p":
instruction = None

embeddings = self.model.encode(
mteb/models/jina_models.py (8 changes: 4 additions & 4 deletions)
@@ -269,8 +269,8 @@ def _resolve_task_parameters(
# Determine prompt name parameter
if jina_task_name and "query" in jina_task_name:
prompt_name_param = "query"
elif jina_task_name and "passage" in jina_task_name:
prompt_name_param = "passage"
elif jina_task_name and "document" in jina_task_name:
prompt_name_param = "document"
else:
prompt_name_param = "query" # default fallback

@@ -549,7 +549,7 @@ def get_programming_task_override(
trust_remote_code=True,
model_prompts={
"Retrieval-query": "retrieval.query",
"Retrieval-passage": "retrieval.passage",
"Retrieval-document": "retrieval.passage",
"STS": "text-matching",
"DocumentUnderstanding": "retrieval.query",
},
@@ -584,7 +584,7 @@ def get_programming_task_override(
trust_remote_code=True,
model_prompts={
"Retrieval-query": "retrieval.query",
"Retrieval-passage": "retrieval.passage",
"Retrieval-document": "retrieval.passage",
"Clustering": "separation",
"Classification": "classification",
"STS": "text-matching",
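The Jina wrappers key their prompts by composite names such as "Retrieval-document". The sketch below shows one plausible lookup for such a dictionary; the keys and values are taken from the jina_models.py hunks, while the helper and its fallback order are assumptions, not the wrapper's exact code. Note that the mapped value stays "retrieval.passage": only the MTEB-side key was renamed, not Jina's own task names.

```python
from __future__ import annotations

from mteb.encoder_interface import PromptType

# Keys and values copied from the jina_models.py hunks above (abbreviated).
model_prompts = {
    "Retrieval-query": "retrieval.query",
    "Retrieval-document": "retrieval.passage",
    "STS": "text-matching",
}


def resolve_jina_prompt(task_type: str, prompt_type: PromptType | None) -> str | None:
    """Try the composite 'TaskType-prompttype' key first, then the bare task type."""
    if prompt_type is not None:
        composite = f"{task_type}-{prompt_type.value}"
        if composite in model_prompts:
            return model_prompts[composite]
    return model_prompts.get(task_type)


print(resolve_jina_prompt("Retrieval", PromptType.document))  # retrieval.passage
print(resolve_jina_prompt("STS", None))  # text-matching
```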