Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions mteb/models/jina_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@
from sentence_transformers import __version__ as st_version

from mteb.encoder_interface import PromptType
from mteb.languages import PROGRAMMING_LANGS
from mteb.model_meta import ModelMeta
from mteb.models.sentence_transformer_wrapper import SentenceTransformerWrapper
from mteb.requires_package import requires_package
from mteb.languages import PROGRAMMING_LANGS

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -234,8 +234,8 @@ def __init__(
)
requires_package(self, "peft", model, "pip install 'mteb[jina-v4]'")
requires_package(self, "torchvision", model, "pip install 'mteb[jina-v4]'")
import peft # noqa: F401
import flash_attn # noqa: F401
import peft # noqa: F401
import transformers # noqa: F401

super().__init__(model, revision, model_prompts, **kwargs)
Expand Down Expand Up @@ -284,8 +284,7 @@ def encode(
def get_programming_task_override(
task_name: str, current_task_name: str | None
) -> str | None:
"""
Check if task involves programming content and override with 'code' task if so.
"""Check if task involves programming content and override with 'code' task if so.

Args:
task_name: Original task name to check
Expand Down
2 changes: 2 additions & 0 deletions tests/test_benchmark/mock_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,8 @@ class MockSentenceTransformer(SentenceTransformer):

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# by default, in SentenceTransformer, prompts are `{"query": "", "document": ""}`
self.prompts = {}

def encode(
self,
Expand Down
5 changes: 2 additions & 3 deletions tests/test_benchmark/test_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import numpy as np
import pytest
import torch
from sentence_transformers import SentenceTransformer

import mteb
import mteb.overview
Expand Down Expand Up @@ -114,7 +113,7 @@ def encode(self, sentences, prompt_name: str | None = None, **kwargs):
assert prompt_name == _task_name
return np.zeros((len(sentences), 10))

class EncoderWithoutInstructions(SentenceTransformer):
class EncoderWithoutInstructions(MockSentenceTransformer):
def encode(self, sentences, **kwargs):
assert kwargs["prompt_name"] is None
return super().encode(sentences, **kwargs)
Expand All @@ -138,7 +137,7 @@ def encode(self, sentences, **kwargs):
overwrite_results=True,
)
# Test that the task_name is not passed down to the encoder
model = EncoderWithoutInstructions("average_word_embeddings_levy_dependency")
model = EncoderWithoutInstructions()
assert model.prompts == {}, "The encoder should not have any prompts"
eval.run(model, output_folder=tmp_path.as_posix(), overwrite_results=True)

Expand Down