-
Notifications
You must be signed in to change notification settings - Fork 555
Add ESC50 and zero-shot classification #2133
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
111 commits
Select commit
Hold shift + click to select a range
56ce4ca
init audio
anime-sh 5d76e4d
some encoder related changes
anime-sh b8a45b2
some more abs task defs
anime-sh d4a34c1
evaluators and classification
anime-sh 72f526a
remove rahul changes to generate first PR
anime-sh a15e64c
make lint
anime-sh c5744bf
init audio
anime-sh 64ccf50
some encoder related changes
anime-sh 1a744c0
some more abs task defs
anime-sh c26ebae
evaluators and classification
anime-sh 1289d9b
remove rahul changes to generate first PR
anime-sh bb2b4d0
make lint
anime-sh 705664e
add dataset/tasks skeleton
silky1708 07eda3c
readd changes lost in rebase
anime-sh ebae179
add fsd50k
silky1708 d51c5d1
add task categories for audio
silky1708 e3b89fa
slight updates to fsd50k
silky1708 849323c
make lint
anime-sh 395b833
wav2vec2 model
anime-sh efd7095
add fsd50k metadata
silky1708 f97f9a3
rename folder
silky1708 6d61f3a
add metric
silky1708 fa61ea6
add torchaudio in req
anime-sh b03a28f
reigster wav2vec2 models
anime-sh 3b57aeb
Merge branch 'maeb' of https://github.com/anime-sh/mteb into maeb
RahulSChand e4aaf9d
fixes
silky1708 d3c20a0
add audio in valid task types
anime-sh 20a45ad
Merge branch 'maeb' of https://github.com/anime-sh/mteb into maeb
silky1708 c92073a
mock interface changes
anime-sh 1b97605
my 0 shot
RahulSChand b9a1c2a
Merge branch 'maeb' of https://github.com/anime-sh/mteb into zero_shot
RahulSChand 63bfaed
make lint
silky1708 2868359
rm audio clustering
silky1708 17949a0
wav2vec2 model revision update
silky1708 1865f84
rm comment
silky1708 3ad782e
rm test.py
silky1708 1ce34ac
add revisions to all wav2vec2 models
silky1708 cb57565
rm empty abstask files
silky1708 792fef3
rm empty evaluator files
silky1708 fdd8935
rm empty task files
silky1708 8def584
Update tests/test_tasks/test_all_abstasks.py
silky1708 26b8b7f
Update mteb/models/wav2vec2_models.py
silky1708 c256ac6
rm non-logReg evaluators for audio classification
silky1708 2379f16
lint
silky1708 babba47
fn name changed to convert_audio_from_numpy
silky1708 8b5f25b
rm mock tests for audio kNN classification
silky1708 64aeb3f
rm evaluators for audio kNN classification
silky1708 6b6ef78
fix imports
silky1708 35bf99a
fix audio kNN; make lint
silky1708 2977dd3
rm AbsTaskAudioClassification.py for later PR
silky1708 2267c98
added zero-shot loading model and dataset checked
RahulSChand 14e9270
remove commented code; reset changes to ClassificationEvaluator.py
silky1708 596d024
fix mock tasks for multilabel classification
silky1708 bd89b17
make lint
silky1708 66e2337
inherit Wrapper class
silky1708 d5ccc3d
add all languages supported by wav2vec2
silky1708 c840112
make lint
silky1708 99725d5
add script info to all languages
silky1708 420d420
make lint
silky1708 f9c16d8
before cleaning comments
RahulSChand b0f53ad
ESC and clap model. Tested 81 percent zero-shot numbers
RahulSChand 1ffaaca
fixed label names for ESC50-multilabel and removed comments
RahulSChand 496ebf5
recent changes
silky1708 7f7b255
Merge branch 'new-maeb' into maeb
anime-sh 4de650f
merge wav2vec2 + add updated logic for auto padding for fqd50k type d…
anime-sh ad03698
make lint remove uwanted files
anime-sh 2479aee
remove debug lines
anime-sh 2d81cea
remove esc50 refs
anime-sh f3821b9
changes for debugging
RahulSChand 8264ab0
lint changes and maeb main branch merge
RahulSChand 4033f68
fix mock tasks for multilabel
anime-sh fba981b
fix mock tasks for multilabel
anime-sh 5eec7d0
Revert "Merge branch 'maeb' into maeb" bad direct commit made to upst…
anime-sh 890f0db
fix model imports
anime-sh c3d338e
merge with maeb animesh branch
RahulSChand 205e9f8
zero shot mock test pass
RahulSChand 87e62bb
fqd50k cleaning
anime-sh e080877
fixed error in Image zero shot classfification
RahulSChand c9c36ba
update fsd50k
silky1708 ce4a8c0
change dataset
anime-sh 75f9975
eval subsets correctly
silky1708 66d437b
Merge branch 'maeb' of https://github.com/anime-sh/mteb into maeb
silky1708 0ebbbec
make lint and remove debug statements
silky1708 e312cff
clean print statements
silky1708 2c71b57
make lint
silky1708 d9e9d6b
update fsd2019 dataset
silky1708 3cb6b5d
remove init in AbsTaskAudioMultilabelClassification.py
silky1708 b94a447
add class parameters in AbsTaskAudioMultilabelClassification
silky1708 e9adeb6
merge remote changes in class init
silky1708 f963e79
inherit from multilingualtask for FSD2019Kaggle
silky1708 a90b2e6
make lint
silky1708 a2d31e7
update mock_tasks; make lint
silky1708 d6fdc00
remove train_split from fn parameters
silky1708 7ff54be
define fsd2019k to be multilingual
silky1708 1700f5a
inherit from MultilingualTask in fsd2019K
silky1708 c9f9aa4
Merge branch 'new-maeb' into maeb
anime-sh 1282cf2
fix tests
anime-sh 6f5e0ba
inherit correct multingial task class
anime-sh 4712cf0
remove MockAudioMultilabelClassificationLogRegTask
silky1708 11ba946
rm other instances of MockAudioMultilabelClassificationLogRegTask
silky1708 75b14df
merged maeb animesh branch
RahulSChand 9ca4837
merged with maeb upstream
RahulSChand d1bb88d
removed unncessary files
RahulSChand 2ecd849
removed unncrssary files
RahulSChand 5083bc1
removed uncrssary files part 3
RahulSChand 25481b9
deleted esc50 from multi label classification
RahulSChand 388849d
fixed errors
RahulSChand 429ac3b
fixed lintng, added precision and recall. Removed extra comments
RahulSChand 1d5e987
fixed double loading of model
RahulSChand a9b0605
filled in missing meta-data
RahulSChand dc2c28d
fixed linting
RahulSChand File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,70 @@ | ||
| from __future__ import annotations | ||
|
|
||
| import logging | ||
| from typing import Any | ||
|
|
||
| from datasets import Dataset | ||
|
|
||
| from ...encoder_interface import Encoder | ||
| from ...evaluation.evaluators import AudioZeroshotClassificationEvaluator | ||
| from ..AbsTask import AbsTask, ScoresDict | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
|
|
||
| class AbsTaskAudioZeroshotClassification(AbsTask): | ||
| """Abstract class for ZeroshotClassification tasks | ||
| The similarity between audio and candidate text prompts, such as as an audio wav of a dog barking and candidate text prompts like "Sound of a dog barking" or "Sound of a airplane". | ||
|
|
||
| self.load_data() must generate a huggingface dataset with a split matching self.metadata_dict["eval_splits"], and assign it to self.dataset. It must contain the following columns: | ||
| image: list of Image.Image | ||
| labels: list of int | ||
| """ | ||
|
|
||
| audio_column_name: str = "audio" | ||
| label_column_name: str = "target" | ||
|
|
||
| def __init__(self, **kwargs): | ||
| super().__init__(**kwargs) | ||
|
|
||
| def _add_main_score(self, scores) -> None: | ||
| scores["main_score"] = scores[self.metadata.main_score] | ||
|
|
||
| def _calculate_metrics_from_split( | ||
| self, split: str, hf_subset: str | None = None, compute_overall: bool = False | ||
| ): | ||
| pass | ||
|
|
||
| def _evaluate_subset( | ||
| self, | ||
| model: Encoder, | ||
| dataset: Dataset, | ||
| *, | ||
| encode_kwargs: dict[str, Any] = {}, | ||
| **kwargs, | ||
| ) -> ScoresDict: | ||
| candidate_labels = self.get_candidate_labels() | ||
|
|
||
| evaluator = AudioZeroshotClassificationEvaluator( | ||
| dataset, | ||
| self.audio_column_name, | ||
| self.label_column_name, | ||
| candidate_labels, | ||
| task_name=self.metadata.name, | ||
| **kwargs, | ||
| ) | ||
| metrics = evaluator(model, encode_kwargs=encode_kwargs) | ||
|
|
||
| scores = { | ||
| "accuracy": metrics["accuracy"], | ||
| "f1": metrics["f1"], | ||
| "f1_weighted": metrics["f1_weighted"], | ||
| "precision": metrics["precision"], | ||
| "recall": metrics["recall"], | ||
| } | ||
| self._add_main_score(scores) | ||
| return scores | ||
|
|
||
| def get_candidate_labels(self) -> list[str]: | ||
| """Return the text candidates for zeroshot classification""" | ||
| raise NotImplementedError("This method should be overridden by subclasses") |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
118 changes: 118 additions & 0 deletions
118
mteb/evaluation/evaluators/Audio/ZeroshotClassificationEvaluator.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,118 @@ | ||
| from __future__ import annotations | ||
|
|
||
| import io | ||
| import logging | ||
| import math | ||
| import os | ||
| from typing import Any | ||
|
|
||
| import torch | ||
| import torchaudio | ||
| from sklearn import metrics | ||
| from torch.utils.data import DataLoader | ||
|
|
||
| from mteb.encoder_interface import Encoder | ||
|
|
||
| from ..Evaluator import Evaluator | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
|
|
||
| class AudioDataset(torch.utils.data.Dataset): | ||
| def __init__(self, hf_dataset, audio_column_name: str = "image", transform=None): | ||
| self.dataset = hf_dataset | ||
| self.transform = transform | ||
| self.audio_column_name = audio_column_name | ||
|
|
||
| def __len__(self): | ||
| return len(self.dataset) | ||
|
|
||
| def __getitem__(self, idx): | ||
| audio = self.dataset[idx][self.audio_column_name] | ||
| if isinstance(audio, bytes): | ||
| waveform, sample_rate = torchaudio.load(io.BytesIO(audio)) | ||
| elif isinstance(audio, str): | ||
| # Assuming audio is a file path | ||
| waveform, sample_rate = torchaudio.load(audio) | ||
| else: | ||
| # Assume audio is already a tensor or in a usable format | ||
| waveform = audio | ||
| if self.transform is not None: | ||
| waveform = self.transform(waveform) | ||
| return waveform | ||
|
|
||
|
|
||
| def custom_collate_fn(batch): | ||
| return batch | ||
|
|
||
|
|
||
| class AudioZeroshotClassificationEvaluator(Evaluator): | ||
| def __init__( | ||
| self, | ||
| dataset, | ||
| audio_column_name: str, | ||
| label_column_name: str, | ||
| candidate_labels: list[str], | ||
| task_name: str | None = None, | ||
| transform=None, | ||
| batch_size: int = 32, | ||
| **kwargs, | ||
| ): | ||
| """Initialize zero-shot audio classification evaluator. | ||
|
|
||
| Args: | ||
| dataset: HuggingFace dataset containing audio data | ||
| audio_column_name: Name of column containing audio data | ||
| label_column_name: Name of column containing label indices | ||
| candidate_labels: List of text descriptions for possible classes | ||
| task_name: Optional name of the task | ||
| transform: Optional audio transforms | ||
| batch_size: Batch size for processing | ||
| **kwargs: Additional keyword arguments | ||
| """ | ||
| super().__init__(**kwargs) | ||
| self.dataset = AudioDataset( | ||
| dataset, audio_column_name=audio_column_name, transform=transform | ||
| ) | ||
| self.labels = dataset[label_column_name] | ||
| self.candidate_labels = candidate_labels | ||
| self.task_name = task_name | ||
| self.batch_size = batch_size | ||
|
|
||
| def __call__( | ||
| self, model: Encoder, *, encode_kwargs: dict[str, Any] = {} | ||
| ) -> dict[str, float]: | ||
| """Evaluate zero-shot classification performance.""" | ||
| logger.info("Getting text embeddings for candidate labels...") | ||
|
|
||
| text_embeddings = model.get_text_embeddings(self.candidate_labels) | ||
|
|
||
| logger.info("Processing audio data...") | ||
| dataloader = DataLoader( | ||
| self.dataset, | ||
| batch_size=encode_kwargs.get("batch_size", self.batch_size), | ||
| collate_fn=custom_collate_fn, | ||
| num_workers=min(math.floor(os.cpu_count() / 2), 16), | ||
| ) | ||
|
|
||
| audio_embeddings = model.get_audio_embeddings(dataloader) | ||
|
|
||
| # Calculate similarity scores | ||
| similarity = ( | ||
| torch.from_numpy(audio_embeddings) @ torch.from_numpy(text_embeddings).T | ||
| ) | ||
|
|
||
| predictions = similarity.argmax(dim=1).cpu().numpy() | ||
|
|
||
| # Calculate metrics | ||
| scores = { | ||
| "accuracy": metrics.accuracy_score(self.labels, predictions), | ||
| "f1": metrics.f1_score(self.labels, predictions, average="macro"), | ||
| "f1_weighted": metrics.f1_score(self.labels, predictions, average="macro"), | ||
| "precision": metrics.precision_score( | ||
| self.labels, predictions, average="macro" | ||
| ), | ||
| "recall": metrics.recall_score(self.labels, predictions, average="macro"), | ||
| } | ||
|
|
||
| return scores | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.