Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
149 changes: 33 additions & 116 deletions tests/mock_tasks.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
"""This implements minimal viable mock tasks for testing the benchmarking framework."""

from __future__ import annotations

from typing import TYPE_CHECKING

import datasets
import numpy as np
from datasets import Dataset, DatasetDict
from sklearn.linear_model import LogisticRegression

Expand All @@ -26,6 +31,9 @@
AbsTaskZeroShotClassification,
)

if TYPE_CHECKING:
from PIL.Image import Image

general_args = {
"description": "a mock task for testing",
"reference": "https://github.com/embeddings-benchmark/mteb",
Expand Down Expand Up @@ -111,6 +119,13 @@ def instruction_retrieval_datasplit() -> RetrievalSplitData:
return base_ds


def create_mock_images(np_rng: np.random.Generator, n: int = 2) -> list[Image]:
from PIL import Image

images = [np_rng.integers(0, 255, (100, 100, 3)) for _ in range(n)]
return [Image.fromarray(image.astype("uint8")).convert("RGBA") for image in images]


class MockClassificationTask(AbsTaskClassification):
classifier = LogisticRegression(n_jobs=1, max_iter=10)

Expand Down Expand Up @@ -1133,17 +1148,8 @@ class MockPairImageClassificationTask(AbsTaskPairClassification):
input2_column_name = "image2"

def load_data(self) -> None:
from PIL import Image

images1 = [self.np_rng.integers(0, 255, (100, 100, 3)) for _ in range(2)]
images1 = [
Image.fromarray(image.astype("uint8")).convert("RGBA") for image in images1
]

images2 = [self.np_rng.integers(0, 255, (100, 100, 3)) for _ in range(2)]
images2 = [
Image.fromarray(image.astype("uint8")).convert("RGBA") for image in images2
]
images1 = create_mock_images(self.np_rng)
images2 = create_mock_images(self.np_rng)

labels = [1, 0]

Expand Down Expand Up @@ -2701,12 +2707,7 @@ class MockMultiChoiceTask(AbsTaskRetrieval):
metadata.category = "it2i"

def load_data(self) -> None:
from PIL import Image

images = [self.np_rng.integers(0, 255, (100, 100, 3)) for _ in range(2)]
images = [
Image.fromarray(image.astype("uint8")).convert("RGBA") for image in images
]
images = create_mock_images(self.np_rng)
retrieval_split_data = RetrievalSplitData(
queries=Dataset.from_dict(
{
Expand All @@ -2716,14 +2717,12 @@ def load_data(self) -> None:
"This is a positive sentence",
"This is another positive sentence",
],
"modality": ["image,text" for _ in range(2)],
}
),
corpus=Dataset.from_dict(
{
"id": ["d1", "d2"],
"image": [images[i] for i in range(2)],
"modality": ["image" for _ in range(2)],
}
),
relevant_docs={
Expand Down Expand Up @@ -2885,12 +2884,7 @@ class MockMultilingualMultiChoiceTask(AbsTaskRetrieval):
metadata.category = "it2i"

def load_data(self) -> None:
from PIL import Image

images = [self.np_rng.integers(0, 255, (100, 100, 3)) for _ in range(2)]
images = [
Image.fromarray(image.astype("uint8")).convert("RGBA") for image in images
]
images = create_mock_images(self.np_rng)

split_data = RetrievalSplitData(
queries=Dataset.from_dict(
Expand All @@ -2901,14 +2895,12 @@ def load_data(self) -> None:
"This is a positive sentence",
"This is another positive sentence",
],
"modality": ["image,text" for _ in range(2)],
}
),
corpus=Dataset.from_dict(
{
"id": ["d1", "d2"],
"image": [images[i] for i in range(2)],
"modality": ["image" for _ in range(2)],
}
),
relevant_docs={
Expand Down Expand Up @@ -2976,19 +2968,13 @@ class MockAny2AnyRetrievalI2TTask(AbsTaskRetrieval):
metadata.category = "i2t"

def load_data(self) -> None:
from PIL import Image

images = [self.np_rng.integers(0, 255, (100, 100, 3)) for _ in range(2)]
images = [
Image.fromarray(image.astype("uint8")).convert("RGBA") for image in images
]
images = create_mock_images(self.np_rng)

retrieval_split_data = RetrievalSplitData(
queries=Dataset.from_dict(
{
"id": [f"q{i}" for i in range(2)],
"image": [images[i] for i in range(2)],
"modality": ["image" for _ in range(2)],
}
),
corpus=Dataset.from_dict(
Expand All @@ -2998,7 +2984,6 @@ def load_data(self) -> None:
"This is a positive sentence",
"This is another positive sentence",
],
"modality": ["text" for _ in range(2)],
}
),
relevant_docs={
Expand Down Expand Up @@ -3055,12 +3040,7 @@ class MockAny2AnyRetrievalT2ITask(AbsTaskRetrieval):
metadata.category = "t2i"

def load_data(self) -> None:
from PIL import Image

images = [self.np_rng.integers(0, 255, (100, 100, 3)) for _ in range(2)]
images = [
Image.fromarray(image.astype("uint8")).convert("RGBA") for image in images
]
images = create_mock_images(self.np_rng)

retrieval_split_data = RetrievalSplitData(
queries=Dataset.from_dict(
Expand All @@ -3070,14 +3050,12 @@ def load_data(self) -> None:
"This is a positive sentence",
"This is another positive sentence",
],
"modality": ["text" for _ in range(2)],
}
),
corpus=Dataset.from_dict(
{
"id": ["d1", "d2"],
"image": [images[i] for i in range(2)],
"modality": ["image" for _ in range(2)],
}
),
relevant_docs={
Expand Down Expand Up @@ -3149,12 +3127,7 @@ class MockImageClassificationTask(AbsTaskClassification):
input_column_name = "image"

def load_data(self) -> None:
from PIL import Image

images = [self.np_rng.integers(0, 255, (100, 100, 3)) for _ in range(2)]
images = [
Image.fromarray(image.astype("uint8")).convert("RGBA") for image in images
]
images = create_mock_images(self.np_rng)
labels = [1, 0]

self.dataset = DatasetDict(
Expand Down Expand Up @@ -3324,12 +3297,7 @@ class MockMultilingualImageClassificationTask(AbsTaskClassification):
input_column_name = "image"

def load_data(self) -> None:
from PIL import Image

images = [self.np_rng.integers(0, 255, (100, 100, 3)) for _ in range(2)]
images = [
Image.fromarray(image.astype("uint8")).convert("RGBA") for image in images
]
images = create_mock_images(self.np_rng)
labels = [1, 0]
data = {
"test": Dataset.from_dict(
Expand Down Expand Up @@ -3390,12 +3358,7 @@ class MockImageClusteringTask(AbsTaskClusteringLegacy):
label_column_name = "label"

def load_data(self) -> None:
from PIL import Image

images = [self.np_rng.integers(0, 255, (100, 100, 3)) for _ in range(2)]
images = [
Image.fromarray(image.astype("uint8")).convert("RGBA") for image in images
]
images = create_mock_images(self.np_rng)
labels = [1, 0]

self.dataset = DatasetDict(
Expand Down Expand Up @@ -3448,12 +3411,7 @@ class MockImageClusteringFastTask(AbsTaskClustering):
max_document_to_embed = 2

def load_data(self) -> None:
from PIL import Image

images = [self.np_rng.integers(0, 255, (100, 100, 3)) for _ in range(2)]
images = [
Image.fromarray(image.astype("uint8")).convert("RGBA") for image in images
]
images = create_mock_images(self.np_rng)
labels = [1, 0]

self.dataset = DatasetDict(
Expand Down Expand Up @@ -3538,12 +3496,7 @@ class MockImageMultilabelClassificationTask(AbsTaskMultilabelClassification):
input_column_name = "image"

def load_data(self) -> None:
from PIL import Image

images = [self.np_rng.integers(0, 255, (100, 100, 3)) for _ in range(2)]
images = [
Image.fromarray(image.astype("uint8")).convert("RGBA") for image in images
]
images = create_mock_images(self.np_rng)
labels = [["0", "3"], ["1", "2"]]

self.dataset = DatasetDict(
Expand Down Expand Up @@ -3743,12 +3696,7 @@ class MockMultilingualImageMultilabelClassificationTask(
input_column_name = "image"

def load_data(self) -> None:
from PIL import Image

images = [self.np_rng.integers(0, 255, (100, 100, 3)) for _ in range(2)]
images = [
Image.fromarray(image.astype("uint8")).convert("RGBA") for image in images
]
images = create_mock_images(self.np_rng)
labels = [["0", "3"], ["1", "2"]]

data = {
Expand Down Expand Up @@ -3808,12 +3756,7 @@ class MockImageTextPairClassificationTask(AbsTaskImageTextPairClassification):
metadata.category = "i2t"

def load_data(self) -> None:
from PIL import Image

images = [self.np_rng.integers(0, 255, (100, 100, 3)) for _ in range(2)]
images = [
Image.fromarray(image.astype("uint8")).convert("RGBA") for image in images
]
images = create_mock_images(self.np_rng)
texts = ["This is a test sentence", "This is another test sentence"]

self.dataset = DatasetDict(
Expand Down Expand Up @@ -3906,12 +3849,7 @@ class MockMultilingualImageTextPairClassificationTask(
metadata.eval_langs = multilingual_eval_langs

def load_data(self) -> None:
from PIL import Image

images = [self.np_rng.integers(0, 255, (100, 100, 3)) for _ in range(2)]
images = [
Image.fromarray(image.astype("uint8")).convert("RGBA") for image in images
]
images = create_mock_images(self.np_rng)
texts = ["This is a test sentence", "This is another test sentence"]
data = {
"test": Dataset.from_dict(
Expand Down Expand Up @@ -3971,12 +3909,7 @@ class MockVisualSTSTask(AbsTaskSTS):
metadata.category = "i2i"

def load_data(self) -> None:
from PIL import Image

images = [self.np_rng.integers(0, 255, (100, 100, 3)) for _ in range(2)]
images = [
Image.fromarray(image.astype("uint8")).convert("RGBA") for image in images
]
images = create_mock_images(self.np_rng)
scores = [0.5, 0.5]

self.dataset = DatasetDict(
Expand Down Expand Up @@ -4037,12 +3970,7 @@ class MockZeroShotClassificationTask(AbsTaskZeroShotClassification):
metadata.category = "i2t"

def load_data(self) -> None:
from PIL import Image

images = [self.np_rng.integers(0, 255, (100, 100, 3)) for _ in range(2)]
images = [
Image.fromarray(image.astype("uint8")).convert("RGBA") for image in images
]
images = create_mock_images(self.np_rng)
labels = ["label1", "label2"]

self.dataset = DatasetDict(
Expand Down Expand Up @@ -4228,19 +4156,8 @@ class MockImageRegressionTask(AbsTaskRegression):
input_column_name = "image"

def load_data(self, **kwargs):
from PIL import Image

train_images = [self.np_rng.integers(0, 255, (100, 100, 3)) for _ in range(2)]
train_images = [
Image.fromarray(image.astype("uint8")).convert("RGBA")
for image in train_images
]

test_images = [self.np_rng.integers(0, 255, (100, 100, 3)) for _ in range(2)]
test_images = [
Image.fromarray(image.astype("uint8")).convert("RGBA")
for image in test_images
]
train_images = create_mock_images(self.np_rng)
test_images = create_mock_images(self.np_rng)

train_values = [1.0, 0.0]
test_values = [1.0, 0.0]
Expand Down