Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 40 additions & 1 deletion mteb/abstasks/_statistics_calculation.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
from __future__ import annotations

import hashlib
from collections import Counter
from collections import Counter, defaultdict
from collections.abc import Mapping
from typing import TYPE_CHECKING, cast

from mteb.types import TopRankedDocumentsType
from mteb.types._encoder_io import AudioInputItem
from mteb.types.statistics import (
AudioStatistics,
ImageStatistics,
LabelStatistics,
RelevantDocsStatistics,
Expand Down Expand Up @@ -73,6 +75,43 @@ def calculate_image_statistics(images: list[Image.Image]) -> ImageStatistics:
)


def calculate_audio_statistics(audios: list[AudioInputItem]) -> AudioStatistics:
    """Calculate descriptive statistics for a list of audio clips.

    Args:
        audios: List of audio clips to analyze. Each audio clip should be a
            dictionary with 'array' (the raw samples, e.g. a numpy array) and
            'sampling_rate' (samples per second) keys.

    Returns:
        A dictionary containing the descriptive statistics: total/min/mean/max
        duration in seconds, the number of unique clips (by content hash), the
        mean sampling rate weighted by clip count, and the per-rate counts.

    Raises:
        ValueError: If `audios` is empty, since no statistics can be computed.
    """
    if not audios:
        # min()/max() below would raise an opaque "min() arg is an empty
        # sequence" — fail early with a message that names the actual problem.
        raise ValueError("Cannot compute audio statistics for an empty list of audios.")

    audio_lengths: list[float] = []
    sampling_rates: dict[int, int] = defaultdict(int)
    unique_audios: set[str] = set()

    for audio in audios:
        array = audio["array"]
        sampling_rate = audio["sampling_rate"]
        # Duration = number of samples / samples-per-second.
        audio_lengths.append(len(array) / sampling_rate)
        sampling_rates[sampling_rate] += 1
        # MD5 over the raw sample bytes detects exact-duplicate clips; used
        # only for dedup counting, not for security.
        unique_audios.add(hashlib.md5(array.tobytes()).hexdigest())

    return AudioStatistics(
        total_duration_seconds=sum(audio_lengths),
        min_duration_seconds=min(audio_lengths),
        average_duration_seconds=sum(audio_lengths) / len(audio_lengths),
        max_duration_seconds=max(audio_lengths),
        unique_audios=len(unique_audios),
        average_sampling_rate=(
            # Weighted mean: each clip contributes its own sampling rate.
            sum(rate * count for rate, count in sampling_rates.items()) / len(audios)
        ),
        sampling_rates=dict(sampling_rates),
    )


def calculate_label_statistics(labels: list[int | list[int]]) -> LabelStatistics:
"""Calculate descriptive statistics for a list of labels.

Expand Down
8 changes: 8 additions & 0 deletions mteb/abstasks/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,15 @@
from mteb.models import EncoderProtocol, MTEBModels
from mteb.types import Array, HFSubset, ScoresDict
from mteb.types.statistics import (
AudioStatistics,
ImageStatistics,
LabelStatistics,
SplitDescriptiveStatistics,
TextStatistics,
)

from ._statistics_calculation import (
calculate_audio_statistics,
calculate_image_statistics,
calculate_label_statistics,
calculate_text_statistics,
Expand All @@ -45,6 +47,7 @@ class ClassificationDescriptiveStatistics(SplitDescriptiveStatistics):

text_statistics: Statistics for text
image_statistics: Statistics for images
audio_statistics: Statistics for audio
label_statistics: Statistics for labels
"""

Expand All @@ -53,6 +56,7 @@ class ClassificationDescriptiveStatistics(SplitDescriptiveStatistics):

text_statistics: TextStatistics | None
image_statistics: ImageStatistics | None
audio_statistics: AudioStatistics | None
label_statistics: LabelStatistics


Expand Down Expand Up @@ -469,6 +473,7 @@ def _calculate_descriptive_statistics_from_split(

image_statistics = None
text_statistics = None
audio_statistics = None
num_texts_in_train = None

if "image" in self.metadata.modalities:
Expand All @@ -480,6 +485,8 @@ def _calculate_descriptive_statistics_from_split(
if split != self.train_split
else None
)
if "audio" in self.metadata.modalities:
audio_statistics = calculate_audio_statistics(inputs)

label_statistics = calculate_label_statistics(label)

Expand All @@ -488,6 +495,7 @@ def _calculate_descriptive_statistics_from_split(
number_texts_intersect_with_train=num_texts_in_train,
text_statistics=text_statistics,
image_statistics=image_statistics,
audio_statistics=audio_statistics,
label_statistics=label_statistics,
)

Expand Down
9 changes: 8 additions & 1 deletion mteb/abstasks/clustering.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,15 @@
from mteb.models import EncoderProtocol, MTEBModels
from mteb.types import Array, HFSubset, ScoresDict
from mteb.types.statistics import (
AudioStatistics,
ImageStatistics,
LabelStatistics,
SplitDescriptiveStatistics,
TextStatistics,
)

from ._statistics_calculation import (
calculate_audio_statistics,
calculate_image_statistics,
calculate_label_statistics,
calculate_text_statistics,
Expand Down Expand Up @@ -104,13 +106,15 @@ class ClusteringFastDescriptiveStatistics(SplitDescriptiveStatistics):

text_statistics: Statistics for text
image_statistics: Statistics for images
audio_statistics: Statistics for audio
labels_statistics: Statistics for labels
"""

num_samples: int

text_statistics: TextStatistics | None
image_statistics: ImageStatistics | None
audio_statistics: AudioStatistics | None
labels_statistics: LabelStatistics


Expand Down Expand Up @@ -271,19 +275,22 @@ def _calculate_descriptive_statistics_from_split(
if isinstance(labels[0], list):
labels = [item for sublist in labels for item in sublist]

text_statistics, image_statistics = None, None
text_statistics, image_statistics, audio_statistics = None, None, None
if "image" in self.metadata.modalities:
image_statistics = calculate_image_statistics(inputs)

if "text" in self.metadata.modalities:
text_statistics = calculate_text_statistics(inputs)
if "audio" in self.metadata.modalities:
audio_statistics = calculate_audio_statistics(inputs)

label_statistics = calculate_label_statistics(labels)

return ClusteringFastDescriptiveStatistics(
num_samples=len(inputs),
text_statistics=text_statistics,
image_statistics=image_statistics,
audio_statistics=audio_statistics,
labels_statistics=label_statistics,
)

Expand Down
10 changes: 9 additions & 1 deletion mteb/abstasks/clustering_legacy.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,15 @@
from mteb.models import EncoderProtocol, MTEBModels
from mteb.types import ScoresDict
from mteb.types.statistics import (
AudioStatistics,
ImageStatistics,
LabelStatistics,
SplitDescriptiveStatistics,
TextStatistics,
)

from ._statistics_calculation import (
calculate_audio_statistics,
calculate_image_statistics,
calculate_label_statistics,
calculate_text_statistics,
Expand All @@ -35,13 +37,15 @@ class ClusteringDescriptiveStatistics(SplitDescriptiveStatistics):

text_statistics: Statistics for text
image_statistics: Statistics for images
audio_statistics: Statistics for audio
label_statistics: Statistics for labels
"""

num_samples: int

text_statistics: TextStatistics | None
image_statistics: ImageStatistics | None
audio_statistics: AudioStatistics | None
label_statistics: LabelStatistics


Expand Down Expand Up @@ -214,19 +218,23 @@ def _calculate_descriptive_statistics_from_split(
if isinstance(labels[0], list):
labels = [item for sublist in labels for item in sublist]

text_statistics, image_statistics = None, None
text_statistics, image_statistics, audio_statistics = None, None, None
if "image" in self.metadata.modalities:
image_statistics = calculate_image_statistics(inputs)

if "text" in self.metadata.modalities:
text_statistics = calculate_text_statistics(inputs)

if "audio" in self.metadata.modalities:
audio_statistics = calculate_audio_statistics(inputs)

label_statistics = calculate_label_statistics(labels)

return ClusteringDescriptiveStatistics(
num_samples=len(inputs),
text_statistics=text_statistics,
image_statistics=image_statistics,
audio_statistics=audio_statistics,
label_statistics=label_statistics,
)

Expand Down
33 changes: 32 additions & 1 deletion mteb/abstasks/pair_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
PairClassificationDistances,
)
from mteb.abstasks._statistics_calculation import (
calculate_audio_statistics,
calculate_image_statistics,
calculate_label_statistics,
calculate_text_statistics,
Expand All @@ -21,6 +22,7 @@
from mteb.models.models_protocols import EncoderProtocol, MTEBModels
from mteb.types import PromptType
from mteb.types.statistics import (
AudioStatistics,
ImageStatistics,
LabelStatistics,
SplitDescriptiveStatistics,
Expand All @@ -39,7 +41,13 @@ class PairClassificationDescriptiveStatistics(SplitDescriptiveStatistics):
unique_pairs: Number of unique pairs

text1_statistics: Statistics for sentence1
image1_statistics: Statistics for image1
audio1_statistics: Statistics for audio1

text2_statistics: Statistics for sentence2
image2_statistics: Statistics for image2
audio2_statistics: Statistics for audio2

labels_statistics: Statistics for labels
"""

Expand All @@ -49,8 +57,10 @@ class PairClassificationDescriptiveStatistics(SplitDescriptiveStatistics):

text1_statistics: TextStatistics | None
image1_statistics: ImageStatistics | None
audio1_statistics: AudioStatistics | None
text2_statistics: TextStatistics | None
image2_statistics: ImageStatistics | None
audio2_statistics: AudioStatistics | None
labels_statistics: LabelStatistics


Expand Down Expand Up @@ -201,6 +211,8 @@ def _calculate_descriptive_statistics_from_split(
image1_statistics = None
image2_statistics = None
number_of_characters = None
audio1_statistics = None
audio2_statistics = None
unique_pairs = None
if self.metadata.modalities == ["text"]:
text1_statistics = calculate_text_statistics(input1)
Expand All @@ -211,7 +223,7 @@ def _calculate_descriptive_statistics_from_split(
)
unique_pairs = len(set(zip(input1, input2)))

elif self.metadata.modalities == ["image"]:
if self.metadata.modalities == ["image"]:
image1_statistics = calculate_image_statistics(input1)
image2_statistics = calculate_image_statistics(input2)

Expand All @@ -227,14 +239,33 @@ def _compute_image_hash(inputs: list) -> list[str]:
image_2_hashes = _compute_image_hash(input2)
unique_pairs = len(set(zip(image_1_hashes, image_2_hashes)))

if self.metadata.modalities == ["audio"]:
audio1_statistics = calculate_audio_statistics(input1)
audio2_statistics = calculate_audio_statistics(input2)

def _compute_audio_hash(inputs: list) -> list[str]:
hashes = set()
for audio in inputs:
array = audio["array"]
audio_bytes = array.tobytes()
audio_hash = hashlib.md5(audio_bytes).hexdigest()
hashes.add(audio_hash)
return list(hashes)

audio_1_hashes = _compute_audio_hash(input1)
audio_2_hashes = _compute_audio_hash(input2)
unique_pairs = len(set(zip(audio_1_hashes, audio_2_hashes)))

return PairClassificationDescriptiveStatistics(
num_samples=len(input1),
unique_pairs=unique_pairs,
number_of_characters=number_of_characters,
text1_statistics=text1_statistics,
image1_statistics=image1_statistics,
audio1_statistics=audio1_statistics,
text2_statistics=text2_statistics,
image2_statistics=image2_statistics,
audio2_statistics=audio2_statistics,
labels_statistics=calculate_label_statistics(labels),
)

Expand Down
Loading
Loading