Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
111 commits
Select commit Hold shift + click to select a range
56ce4ca
init audio
anime-sh Feb 12, 2025
5d76e4d
some encoder related changes
anime-sh Feb 16, 2025
b8a45b2
some more abs task defs
anime-sh Feb 17, 2025
d4a34c1
evaluators and classification
anime-sh Feb 17, 2025
72f526a
remove rahul changes to generate first PR
anime-sh Feb 17, 2025
a15e64c
make lint
anime-sh Feb 17, 2025
c5744bf
init audio
anime-sh Feb 12, 2025
64ccf50
some encoder related changes
anime-sh Feb 16, 2025
1a744c0
some more abs task defs
anime-sh Feb 17, 2025
c26ebae
evaluators and classification
anime-sh Feb 17, 2025
1289d9b
remove rahul changes to generate first PR
anime-sh Feb 17, 2025
bb2b4d0
make lint
anime-sh Feb 17, 2025
705664e
add dataset/tasks skeleton
silky1708 Feb 21, 2025
07eda3c
readd changes lost in rebase
anime-sh Feb 21, 2025
ebae179
add fsd50k
silky1708 Feb 21, 2025
d51c5d1
add task categories for audio
silky1708 Feb 21, 2025
e3b89fa
slight updates to fsd50k
silky1708 Feb 21, 2025
849323c
make lint
anime-sh Feb 21, 2025
395b833
wav2vec2 model
anime-sh Feb 21, 2025
efd7095
add fsd50k metadata
silky1708 Feb 21, 2025
f97f9a3
rename folder
silky1708 Feb 21, 2025
6d61f3a
add metric
silky1708 Feb 21, 2025
fa61ea6
add torchaudio in req
anime-sh Feb 21, 2025
b03a28f
reigster wav2vec2 models
anime-sh Feb 21, 2025
3b57aeb
Merge branch 'maeb' of https://github.com/anime-sh/mteb into maeb
RahulSChand Feb 21, 2025
e4aaf9d
fixes
silky1708 Feb 21, 2025
d3c20a0
add audio in valid task types
anime-sh Feb 21, 2025
20a45ad
Merge branch 'maeb' of https://github.com/anime-sh/mteb into maeb
silky1708 Feb 21, 2025
c92073a
mock interface changes
anime-sh Feb 21, 2025
1b97605
my 0 shot
RahulSChand Feb 21, 2025
b9a1c2a
Merge branch 'maeb' of https://github.com/anime-sh/mteb into zero_shot
RahulSChand Feb 21, 2025
63bfaed
make lint
silky1708 Feb 21, 2025
2868359
rm audio clustering
silky1708 Feb 21, 2025
17949a0
wav2vec2 model revision update
silky1708 Feb 21, 2025
1865f84
rm comment
silky1708 Feb 21, 2025
3ad782e
rm test.py
silky1708 Feb 21, 2025
1ce34ac
add revisions to all wav2vec2 models
silky1708 Feb 21, 2025
cb57565
rm empty abstask files
silky1708 Feb 21, 2025
792fef3
rm empty evaluator files
silky1708 Feb 21, 2025
fdd8935
rm empty task files
silky1708 Feb 21, 2025
8def584
Update tests/test_tasks/test_all_abstasks.py
silky1708 Feb 21, 2025
26b8b7f
Update mteb/models/wav2vec2_models.py
silky1708 Feb 21, 2025
c256ac6
rm non-logReg evaluators for audio classification
silky1708 Feb 21, 2025
2379f16
lint
silky1708 Feb 21, 2025
babba47
fn name changed to convert_audio_from_numpy
silky1708 Feb 21, 2025
8b5f25b
rm mock tests for audio kNN classification
silky1708 Feb 21, 2025
64aeb3f
rm evaluators for audio kNN classification
silky1708 Feb 21, 2025
6b6ef78
fix imports
silky1708 Feb 21, 2025
35bf99a
fix audio kNN; make lint
silky1708 Feb 21, 2025
2977dd3
rm AbsTaskAudioClassification.py for later PR
silky1708 Feb 21, 2025
2267c98
added zero-shot loading model and dataset checked
RahulSChand Feb 21, 2025
14e9270
remove commented code; reset changes to ClassificationEvaluator.py
silky1708 Feb 24, 2025
596d024
fix mock tasks for multilabel classification
silky1708 Feb 24, 2025
bd89b17
make lint
silky1708 Feb 24, 2025
66e2337
inherit Wrapper class
silky1708 Feb 24, 2025
d5ccc3d
add all languages supported by wav2vec2
silky1708 Feb 24, 2025
c840112
make lint
silky1708 Feb 24, 2025
99725d5
add script info to all languages
silky1708 Feb 24, 2025
420d420
make lint
silky1708 Feb 24, 2025
f9c16d8
before cleaning comments
RahulSChand Feb 25, 2025
b0f53ad
ESC and clap model. Tested 81 percent zero-shot numbers
RahulSChand Feb 25, 2025
1ffaaca
fixed label names for ESC50-multilabel and removed comments
RahulSChand Feb 25, 2025
496ebf5
recent changes
silky1708 Mar 1, 2025
7f7b255
Merge branch 'new-maeb' into maeb
anime-sh Mar 1, 2025
4de650f
merge wav2vec2 + add updated logic for auto padding for fqd50k type d…
anime-sh Mar 1, 2025
ad03698
make lint remove uwanted files
anime-sh Mar 1, 2025
2479aee
remove debug lines
anime-sh Mar 1, 2025
2d81cea
remove esc50 refs
anime-sh Mar 1, 2025
f3821b9
changes for debugging
RahulSChand Mar 1, 2025
8264ab0
lint changes and maeb main branch merge
RahulSChand Mar 1, 2025
4033f68
fix mock tasks for multilabel
anime-sh Mar 1, 2025
fba981b
fix mock tasks for multilabel
anime-sh Mar 1, 2025
5eec7d0
Revert "Merge branch 'maeb' into maeb" bad direct commit made to upst…
anime-sh Mar 1, 2025
890f0db
fix model imports
anime-sh Mar 1, 2025
c3d338e
merge with maeb animesh branch
RahulSChand Mar 1, 2025
205e9f8
zero shot mock test pass
RahulSChand Mar 1, 2025
87e62bb
fqd50k cleaning
anime-sh Mar 1, 2025
e080877
fixed error in Image zero shot classfification
RahulSChand Mar 1, 2025
c9c36ba
update fsd50k
silky1708 Mar 1, 2025
ce4a8c0
change dataset
anime-sh Mar 1, 2025
75f9975
eval subsets correctly
silky1708 Mar 1, 2025
66d437b
Merge branch 'maeb' of https://github.com/anime-sh/mteb into maeb
silky1708 Mar 1, 2025
0ebbbec
make lint and remove debug statements
silky1708 Mar 1, 2025
e312cff
clean print statements
silky1708 Mar 1, 2025
2c71b57
make lint
silky1708 Mar 1, 2025
d9e9d6b
update fsd2019 dataset
silky1708 Mar 1, 2025
3cb6b5d
remove init in AbsTaskAudioMultilabelClassification.py
silky1708 Mar 2, 2025
b94a447
add class parameters in AbsTaskAudioMultilabelClassification
silky1708 Mar 2, 2025
e9adeb6
merge remote changes in class init
silky1708 Mar 2, 2025
f963e79
inherit from multilingualtask for FSD2019Kaggle
silky1708 Mar 2, 2025
a90b2e6
make lint
silky1708 Mar 2, 2025
a2d31e7
update mock_tasks; make lint
silky1708 Mar 2, 2025
d6fdc00
remove train_split from fn parameters
silky1708 Mar 2, 2025
7ff54be
define fsd2019k to be multilingual
silky1708 Mar 2, 2025
1700f5a
inherit from MultilingualTask in fsd2019K
silky1708 Mar 2, 2025
c9f9aa4
Merge branch 'new-maeb' into maeb
anime-sh Mar 2, 2025
1282cf2
fix tests
anime-sh Mar 3, 2025
6f5e0ba
inherit correct multingial task class
anime-sh Mar 3, 2025
4712cf0
remove MockAudioMultilabelClassificationLogRegTask
silky1708 Mar 3, 2025
11ba946
rm other instances of MockAudioMultilabelClassificationLogRegTask
silky1708 Mar 3, 2025
75b14df
merged maeb animesh branch
RahulSChand Mar 5, 2025
9ca4837
merged with maeb upstream
RahulSChand Mar 5, 2025
d1bb88d
removed unncessary files
RahulSChand Mar 5, 2025
2ecd849
removed unncrssary files
RahulSChand Mar 5, 2025
5083bc1
removed uncrssary files part 3
RahulSChand Mar 5, 2025
25481b9
deleted esc50 from multi label classification
RahulSChand Mar 5, 2025
388849d
fixed errors
RahulSChand Mar 5, 2025
429ac3b
fixed lintng, added precision and recall. Removed extra comments
RahulSChand Mar 5, 2025
1d5e987
fixed double loading of model
RahulSChand Mar 5, 2025
a9b0605
filled in missing meta-data
RahulSChand Mar 5, 2025
dc2c28d
fixed linting
RahulSChand Mar 5, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 70 additions & 0 deletions mteb/abstasks/Audio/AbsTaskAudioZeroshotClassification.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
from __future__ import annotations

import logging
from typing import Any

from datasets import Dataset

from ...encoder_interface import Encoder
from ...evaluation.evaluators import AudioZeroshotClassificationEvaluator
from ..AbsTask import AbsTask, ScoresDict

logger = logging.getLogger(__name__)


class AbsTaskAudioZeroshotClassification(AbsTask):
"""Abstract class for ZeroshotClassification tasks
The similarity between audio and candidate text prompts, such as as an audio wav of a dog barking and candidate text prompts like "Sound of a dog barking" or "Sound of a airplane".

self.load_data() must generate a huggingface dataset with a split matching self.metadata_dict["eval_splits"], and assign it to self.dataset. It must contain the following columns:
image: list of Image.Image
labels: list of int
"""

audio_column_name: str = "audio"
label_column_name: str = "target"

def __init__(self, **kwargs):
super().__init__(**kwargs)

def _add_main_score(self, scores) -> None:
scores["main_score"] = scores[self.metadata.main_score]

def _calculate_metrics_from_split(
self, split: str, hf_subset: str | None = None, compute_overall: bool = False
):
pass

def _evaluate_subset(
self,
model: Encoder,
dataset: Dataset,
*,
encode_kwargs: dict[str, Any] = {},
**kwargs,
) -> ScoresDict:
candidate_labels = self.get_candidate_labels()

evaluator = AudioZeroshotClassificationEvaluator(
dataset,
self.audio_column_name,
self.label_column_name,
candidate_labels,
task_name=self.metadata.name,
**kwargs,
)
metrics = evaluator(model, encode_kwargs=encode_kwargs)

scores = {
"accuracy": metrics["accuracy"],
"f1": metrics["f1"],
"f1_weighted": metrics["f1_weighted"],
"precision": metrics["precision"],
"recall": metrics["recall"],
}
self._add_main_score(scores)
return scores

def get_candidate_labels(self) -> list[str]:
"""Return the text candidates for zeroshot classification"""
raise NotImplementedError("This method should be overridden by subclasses")
1 change: 1 addition & 0 deletions mteb/abstasks/TaskMetadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@
"VisualSTS",
"ZeroShotClassification",
"AudioMultilabelClassification",
"AudioZeroshotClassification",
]


Expand Down
118 changes: 118 additions & 0 deletions mteb/evaluation/evaluators/Audio/ZeroshotClassificationEvaluator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
from __future__ import annotations

import io
import logging
import math
import os
from typing import Any

import torch
import torchaudio
from sklearn import metrics
from torch.utils.data import DataLoader

from mteb.encoder_interface import Encoder

from ..Evaluator import Evaluator

logger = logging.getLogger(__name__)


class AudioDataset(torch.utils.data.Dataset):
def __init__(self, hf_dataset, audio_column_name: str = "image", transform=None):
self.dataset = hf_dataset
self.transform = transform
self.audio_column_name = audio_column_name

def __len__(self):
return len(self.dataset)

def __getitem__(self, idx):
audio = self.dataset[idx][self.audio_column_name]
if isinstance(audio, bytes):
waveform, sample_rate = torchaudio.load(io.BytesIO(audio))
elif isinstance(audio, str):
# Assuming audio is a file path
waveform, sample_rate = torchaudio.load(audio)
else:
# Assume audio is already a tensor or in a usable format
waveform = audio
if self.transform is not None:
waveform = self.transform(waveform)
return waveform


def custom_collate_fn(batch):
return batch


class AudioZeroshotClassificationEvaluator(Evaluator):
def __init__(
self,
dataset,
audio_column_name: str,
label_column_name: str,
candidate_labels: list[str],
task_name: str | None = None,
transform=None,
batch_size: int = 32,
**kwargs,
):
"""Initialize zero-shot audio classification evaluator.

Args:
dataset: HuggingFace dataset containing audio data
audio_column_name: Name of column containing audio data
label_column_name: Name of column containing label indices
candidate_labels: List of text descriptions for possible classes
task_name: Optional name of the task
transform: Optional audio transforms
batch_size: Batch size for processing
**kwargs: Additional keyword arguments
"""
super().__init__(**kwargs)
self.dataset = AudioDataset(
dataset, audio_column_name=audio_column_name, transform=transform
)
self.labels = dataset[label_column_name]
self.candidate_labels = candidate_labels
self.task_name = task_name
self.batch_size = batch_size

def __call__(
self, model: Encoder, *, encode_kwargs: dict[str, Any] = {}
) -> dict[str, float]:
"""Evaluate zero-shot classification performance."""
logger.info("Getting text embeddings for candidate labels...")

text_embeddings = model.get_text_embeddings(self.candidate_labels)

logger.info("Processing audio data...")
dataloader = DataLoader(
self.dataset,
batch_size=encode_kwargs.get("batch_size", self.batch_size),
collate_fn=custom_collate_fn,
num_workers=min(math.floor(os.cpu_count() / 2), 16),
)

audio_embeddings = model.get_audio_embeddings(dataloader)

# Calculate similarity scores
similarity = (
torch.from_numpy(audio_embeddings) @ torch.from_numpy(text_embeddings).T
)

predictions = similarity.argmax(dim=1).cpu().numpy()

# Calculate metrics
scores = {
"accuracy": metrics.accuracy_score(self.labels, predictions),
"f1": metrics.f1_score(self.labels, predictions, average="macro"),
"f1_weighted": metrics.f1_score(self.labels, predictions, average="macro"),
"precision": metrics.precision_score(
self.labels, predictions, average="macro"
),
"recall": metrics.recall_score(self.labels, predictions, average="macro"),
}

return scores
1 change: 1 addition & 0 deletions mteb/evaluation/evaluators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from .Audio.Any2AnyRetrievalEvaluator import *
from .Audio.ClassificationEvaluator import *
from .Audio.ClusteringEvaluator import *
from .Audio.ZeroshotClassificationEvaluator import *
from .BitextMiningEvaluator import *
from .ClassificationEvaluator import *
from .ClusteringEvaluator import *
Expand Down
Loading
Loading