Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion mteb/models/sentence_transformer_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,11 @@ def __init__(
):
try:
model_prompts = self.validate_task_to_prompt_name(self.model.prompts)
except ValueError:
except KeyError:
model_prompts = None
logger.warning(
"Model prompts are not in the expected format. Ignoring them."
)
elif model_prompts is not None and hasattr(self.model, "prompts"):
logger.info(f"Model prompts will be overwritten with {model_prompts}")
self.model.prompts = model_prompts
Expand Down
12 changes: 6 additions & 6 deletions mteb/models/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,15 +76,15 @@ def validate_task_to_prompt_name(
if "-" in task_name:
task_name, prompt_type = task_name.split("-")
if prompt_type not in prompt_types:
raise ValueError(
f"Prompt type {prompt_type} is not valid. Valid prompt types are {prompt_types}"
)
msg = f"Prompt type {prompt_type} is not valid. Valid prompt types are {prompt_types}"
logger.warning(msg)
raise KeyError(msg)
if task_name not in task_types and task_name not in prompt_types:
task = mteb.get_task(task_name=task_name)
if not task:
raise ValueError(
f"Task name {task_name} is not valid. Valid task names are task types [{task_types}], prompt types [{prompt_types}] and task names"
)
msg = f"Task name {task_name} is not valid. Valid task names are task types [{task_types}], prompt types [{prompt_types}] and task names"
logger.warning(msg)
raise KeyError(msg)
return task_to_prompt_name

@staticmethod
Expand Down
2 changes: 1 addition & 1 deletion tests/test_reproducible_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,5 +71,5 @@ def test_validate_task_to_prompt_name_fail():
{"task_name": "prompt_name", "task_name-query": "prompt_name"}
)

with pytest.raises(ValueError):
with pytest.raises(KeyError):
Wrapper.validate_task_to_prompt_name({"task_name-task_name": "prompt_name"})
Loading