Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 71 additions & 3 deletions mteb/abstasks/Image/AbsTaskImageClassification.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import logging
from collections import defaultdict
from collections import Counter, defaultdict
from typing import Any

import numpy as np
Expand All @@ -16,12 +16,45 @@
ImagelogRegClassificationEvaluator,
)
from ..AbsTask import AbsTask, ScoresDict
from ..TaskMetadata import DescriptiveStatistics

ImageFile.LOAD_TRUNCATED_IMAGES = True

logger = logging.getLogger(__name__)


class ImageClassificationDescriptiveStatistics(DescriptiveStatistics):
"""Descriptive statistics for ImageClassification

Attributes:
num_samples: number of samples in the dataset.

min_image_width: Minimum width of images
average_image_width: Average width of images
max_image_width: Maximum width of images

min_image_height: Minimum height of images
average_image_height: Average height of images
max_image_height: Maximum height of images

unique_labels: Number of unique labels
labels: dict of label frequencies
"""

num_samples: int

min_image_width: float
average_image_width: float
max_image_width: float

min_image_height: float
average_image_height: float
max_image_height: float

unique_num_labels: int
labels: dict[str, dict[str, int]]


class AbsTaskImageClassification(AbsTask):
"""Abstract class for kNN classification tasks
The similarity is computed between pairs and the results are ranked.
Expand Down Expand Up @@ -73,8 +106,43 @@ def _add_main_score(self, scores: dict[HFSubset, ScoresDict]) -> None:

def _calculate_metrics_from_split(
self, split: str, hf_subset: str | None = None, compute_overall: bool = False
):
pass
) -> ImageClassificationDescriptiveStatistics:
if hf_subset:
imgs = self.dataset[hf_subset][split][self.image_column_name]
labels = self.dataset[hf_subset][split][self.label_column_name]
elif compute_overall:
imgs = []
labels = []
for hf_subset in self.metadata.eval_langs:
imgs.extend(self.dataset[hf_subset][split][self.image_column_name])
labels.extend(self.dataset[hf_subset][split][self.label_column_name])
else:
imgs = self.dataset[split][self.image_column_name]
labels = self.dataset[split][self.label_column_name]

num_samples = len(labels)
unique_num_labels = len(set(labels))
label_count = Counter(labels)

img_widths, img_heights = [], []
for img in imgs:
width, height = img.size
img_heights.append(height)
img_widths.append(width)

return ImageClassificationDescriptiveStatistics(
num_samples=num_samples,
unique_num_labels=unique_num_labels,
min_image_width=min(img_widths),
average_image_width=sum(img_widths) / len(img_widths),
max_image_width=max(img_widths),
min_image_height=min(img_heights),
average_image_height=sum(img_heights) / len(img_heights),
max_image_height=max(img_heights),
labels={
str(label): {"count": count} for label, count in label_count.items()
},
)

def evaluate(
self,
Expand Down
32 changes: 20 additions & 12 deletions mteb/abstasks/TaskMetadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,20 @@
"rendered",
"multiple",
]
TASK_TYPE = Literal[

MIEB_TASK_TYPE = (
"Any2AnyMultiChoice",
"Any2AnyRetrieval",
"Any2TextMutipleChoice",
"ImageClustering",
"ImageClassification",
"ImageMultilabelClassification",
"ImageTextPairClassification",
"VisualSTS",
"ZeroShotClassification",
)

TASK_TYPE = (
"BitextMining",
"Classification",
"MultilabelClassification",
Expand All @@ -112,16 +125,9 @@
"Summarization",
"InstructionRetrieval",
"Speed",
"Any2AnyMultiChoice",
"Any2AnyRetrieval",
"Any2TextMutipleChoice",
"ImageClustering",
"ImageClassification",
"ImageMultilabelClassification",
"ImageTextPairClassification",
"VisualSTS",
"ZeroShotClassification",
]
) + MIEB_TASK_TYPE

TASK_TYPE = Literal[TASK_TYPE]


TASK_CATEGORY = Literal[
Expand Down Expand Up @@ -455,9 +461,11 @@ def descriptive_stats(self) -> dict[str, DescriptiveStatistics] | None:
def descriptive_stat_path(self) -> Path:
"""Return the path to the descriptive statistics file."""
descriptive_stat_base_dir = Path(__file__).parent.parent / "descriptive_stats"
if self.type in MIEB_TASK_TYPE:
descriptive_stat_base_dir = descriptive_stat_base_dir / "Image"
task_type_dir = descriptive_stat_base_dir / self.type
if not descriptive_stat_base_dir.exists():
descriptive_stat_base_dir.mkdir()
task_type_dir = descriptive_stat_base_dir / self.type
if not task_type_dir.exists():
task_type_dir.mkdir()
return task_type_dir / f"{self.name}.json"
Expand Down
44 changes: 44 additions & 0 deletions mteb/descriptive_stats/Image/ImageClassification/MNIST.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
{
"test": {
"num_samples": 10000,
"unique_num_labels": 10,
"min_image_width": 28,
"average_image_width": 28.0,
"max_image_width": 28,
"min_image_height": 28,
"average_image_height": 28.0,
"max_image_height": 28,
"labels": {
"7": {
"count": 1028
},
"2": {
"count": 1032
},
"1": {
"count": 1135
},
"0": {
"count": 980
},
"4": {
"count": 982
},
"9": {
"count": 1009
},
"5": {
"count": 892
},
"6": {
"count": 958
},
"3": {
"count": 1010
},
"8": {
"count": 974
}
}
}
}
Loading