diff --git a/mteb/abstasks/Audio/AbsTaskAudioZeroshotClassification.py b/mteb/abstasks/Audio/AbsTaskAudioZeroshotClassification.py
new file mode 100644
index 0000000000..c5f4086d1d
--- /dev/null
+++ b/mteb/abstasks/Audio/AbsTaskAudioZeroshotClassification.py
@@ -0,0 +1,70 @@
+from __future__ import annotations
+
+import logging
+from typing import Any
+
+from datasets import Dataset
+
+from ...encoder_interface import Encoder
+from ...evaluation.evaluators import AudioZeroshotClassificationEvaluator
+from ..AbsTask import AbsTask, ScoresDict
+
+logger = logging.getLogger(__name__)
+
+
+class AbsTaskAudioZeroshotClassification(AbsTask):
+    """Abstract class for audio zero-shot classification tasks.
+    Scores the similarity between audio clips and candidate text prompts, e.g. a recording of a dog barking against prompts such as "Sound of a dog barking" or "Sound of an airplane".
+
+    self.load_data() must generate a huggingface dataset with a split matching self.metadata_dict["eval_splits"], and assign it to self.dataset. It must contain the following columns:
+        audio: list of audio samples
+        target: list of int labels
+    """
+
+    audio_column_name: str = "audio"
+    label_column_name: str = "target"
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+
+    def _add_main_score(self, scores) -> None:
+        scores["main_score"] = scores[self.metadata.main_score]
+
+    def _calculate_metrics_from_split(
+        self, split: str, hf_subset: str | None = None, compute_overall: bool = False
+    ):
+        pass
+
+    def _evaluate_subset(
+        self,
+        model: Encoder,
+        dataset: Dataset,
+        *,
+        encode_kwargs: dict[str, Any] = {},
+        **kwargs,
+    ) -> ScoresDict:
+        candidate_labels = self.get_candidate_labels()
+
+        evaluator = AudioZeroshotClassificationEvaluator(
+            dataset,
+            self.audio_column_name,
+            self.label_column_name,
+            candidate_labels,
+            task_name=self.metadata.name,
+            **kwargs,
+        )
+        metrics = evaluator(model, encode_kwargs=encode_kwargs)
+
+        scores = {
+            "accuracy": metrics["accuracy"],
+            "f1": metrics["f1"],
+            "f1_weighted": metrics["f1_weighted"],
+            "precision": metrics["precision"],
+            "recall": metrics["recall"],
+        }
+        self._add_main_score(scores)
+        return scores
+
+    def get_candidate_labels(self) -> list[str]:
+        """Return the text candidates for zeroshot classification"""
+        raise NotImplementedError("This method should be overridden by subclasses")
diff --git a/mteb/abstasks/TaskMetadata.py b/mteb/abstasks/TaskMetadata.py
index 79542e4bce..6834fcaa70 100644
--- a/mteb/abstasks/TaskMetadata.py
+++ b/mteb/abstasks/TaskMetadata.py
@@ -125,6 +125,7 @@
     "VisualSTS",
     "ZeroShotClassification",
     "AudioMultilabelClassification",
+    "AudioZeroshotClassification",
 ]
diff --git a/mteb/evaluation/evaluators/Audio/ZeroshotClassificationEvaluator.py b/mteb/evaluation/evaluators/Audio/ZeroshotClassificationEvaluator.py
new file mode 100644
index 0000000000..4e03b293af
--- /dev/null
+++ b/mteb/evaluation/evaluators/Audio/ZeroshotClassificationEvaluator.py
@@ -0,0 +1,118 @@
+from __future__ import annotations
+
+import io
+import logging
+import math
+import os
+from typing import Any
+
+import torch
+import torchaudio
+from sklearn import metrics
+from torch.utils.data import DataLoader
+
+from mteb.encoder_interface import Encoder
+
+from ..Evaluator import Evaluator
+
+logger = logging.getLogger(__name__)
+
+
+class AudioDataset(torch.utils.data.Dataset):
+    def __init__(self, hf_dataset, audio_column_name: str = "audio", transform=None):
+        self.dataset = hf_dataset
+        self.transform = transform
+        self.audio_column_name = audio_column_name
+
+    def __len__(self):
+        return len(self.dataset)
+
+    def __getitem__(self, idx):
+        audio = self.dataset[idx][self.audio_column_name]
+        if isinstance(audio, bytes):
+            waveform, sample_rate = torchaudio.load(io.BytesIO(audio))
+        elif isinstance(audio, str):
+            # Assuming audio is a file path
+            waveform, sample_rate = torchaudio.load(audio)
+        else:
+            # Assume audio is already a tensor or in a usable format
+            waveform = audio
+        if self.transform is not None:
+            waveform = self.transform(waveform)
+        return waveform
+
+
+def custom_collate_fn(batch):
+    return batch
+
+
+class AudioZeroshotClassificationEvaluator(Evaluator):
+    def __init__(
+        self,
+        dataset,
+        audio_column_name: str,
+        label_column_name: str,
+        candidate_labels: list[str],
+        task_name: str | None = None,
+        transform=None,
+        batch_size: int = 32,
+        **kwargs,
+    ):
+        """Initialize zero-shot audio classification evaluator.
+
+        Args:
+            dataset: HuggingFace dataset containing audio data
+            audio_column_name: Name of the column containing audio data
+            label_column_name: Name of the column containing label indices
+            candidate_labels: List of text descriptions for the possible classes
+            task_name: Optional name of the task
+            transform: Optional audio transforms
+            batch_size: Batch size for processing
+            **kwargs: Additional keyword arguments
+        """
+        super().__init__(**kwargs)
+        self.dataset = AudioDataset(
+            dataset, audio_column_name=audio_column_name, transform=transform
+        )
+        self.labels = dataset[label_column_name]
+        self.candidate_labels = candidate_labels
+        self.task_name = task_name
+        self.batch_size = batch_size
+
+    def __call__(
+        self, model: Encoder, *, encode_kwargs: dict[str, Any] = {}
+    ) -> dict[str, float]:
+        """Evaluate zero-shot classification performance."""
+        logger.info("Getting text embeddings for candidate labels...")
+
+        text_embeddings = model.get_text_embeddings(self.candidate_labels)
+
+        logger.info("Processing audio data...")
+        dataloader = DataLoader(
+            self.dataset,
+            batch_size=encode_kwargs.get("batch_size", self.batch_size),
+            collate_fn=custom_collate_fn,
+            num_workers=min(math.floor(os.cpu_count() / 2), 16),
+        )
+
+        audio_embeddings = model.get_audio_embeddings(dataloader)
+
+        # Calculate similarity scores
+        similarity = (
+            torch.from_numpy(audio_embeddings) @ torch.from_numpy(text_embeddings).T
+        )
+
+        predictions = similarity.argmax(dim=1).cpu().numpy()
+
+        # Calculate metrics; f1_weighted uses a weighted average to reflect class imbalance
+        scores = {
+            "accuracy": metrics.accuracy_score(self.labels, predictions),
+            "f1": metrics.f1_score(self.labels, predictions, average="macro"),
+            "f1_weighted": metrics.f1_score(self.labels, predictions, average="weighted"),
+            "precision": metrics.precision_score(
+                self.labels, predictions, average="macro"
+            ),
+            "recall": metrics.recall_score(self.labels, predictions, average="macro"),
+        }
+
+        return scores
diff --git a/mteb/evaluation/evaluators/__init__.py b/mteb/evaluation/evaluators/__init__.py
index 82459cc5dd..eecbe79e8f 100644
--- a/mteb/evaluation/evaluators/__init__.py
+++ b/mteb/evaluation/evaluators/__init__.py
@@ -3,6 +3,7 @@
 from .Audio.Any2AnyRetrievalEvaluator import *
 from .Audio.ClassificationEvaluator import *
 from .Audio.ClusteringEvaluator import *
+from .Audio.ZeroshotClassificationEvaluator import *
 from .BitextMiningEvaluator import *
 from .ClassificationEvaluator import *
 from .ClusteringEvaluator import *
diff --git a/mteb/models/clap_models.py b/mteb/models/clap_models.py
new file mode 100644
index 0000000000..351d943ccc
--- /dev/null
+++ b/mteb/models/clap_models.py
@@ -0,0 +1,214 @@
+from __future__ import annotations
+
+from collections.abc import Iterable
+from functools import partial
+from typing import Any
+
+import numpy as np
+import torch
+import torchaudio
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+from transformers import ClapModel, ClapProcessor, pipeline
+
+from mteb.encoder_interface import AudioBatch, AudioData, PromptType
+from mteb.model_meta import ModelMeta
+
+
+class ClapZeroShotWrapper:
+    def __init__(
+        self,
+        model_name: str = "laion/clap-htsat-fused",
+        device: str = "cuda" if torch.cuda.is_available() else "cpu",
+        **kwargs: Any,
+    ):
+        self.model_name = model_name
+        self.device = device
+        self.model = ClapModel.from_pretrained(model_name).to(self.device)
+        self.processor = ClapProcessor.from_pretrained(model_name)
+        self.sampling_rate = self.processor.feature_extractor.sampling_rate
+
+        self.pipeline = pipeline(
+            task="zero-shot-audio-classification",
+            model=self.model,
+            feature_extractor=self.processor.feature_extractor,  # reuse the processor's feature extractor
+            tokenizer=self.processor.tokenizer,  # reuse the processor's tokenizer
+            device=device,
+            **kwargs,
+        )
+
+    def _process_audio(self, audio: AudioBatch) -> list[torch.Tensor]:
+        processed_audio = []
+
+        if isinstance(audio, DataLoader):
+            for batch in audio:
+                processed_audio.extend(self._handle_batch(batch))
+        else:
+            processed_audio = self._handle_batch(audio)
+
+        return processed_audio
+
+    def _handle_batch(
+        self, batch: AudioData | Iterable[tuple[AudioData, str]]
+    ) -> list[torch.Tensor]:
+        waveforms = []
+
+        if isinstance(batch, tuple):  # Handle (audio, metadata) tuples
+            for audio, _ in batch:
+                waveforms.append(self._convert_audio(audio))
+        else:
+            for item in batch:
+                if isinstance(item, dict):
+                    if "array" in item:
+                        audio = item["array"]
+                        # Convert to torch tensor and ensure float32
+                        audio = (
+                            torch.from_numpy(audio).float()
+                            if isinstance(audio, np.ndarray)
+                            else audio.float()
+                        )
+                        if item["sampling_rate"] != self.sampling_rate:
+                            resampler = torchaudio.transforms.Resample(
+                                item["sampling_rate"], self.sampling_rate
+                            )
+                            audio = resampler(audio)
+                        waveforms.append(self._convert_audio(audio))
+                    elif "path" in item:
+                        waveforms.append(self._load_audio_file(item["path"]))
+                elif isinstance(item, (np.ndarray, torch.Tensor)):
+                    waveforms.append(self._convert_audio(item))
+                elif isinstance(item, str):
+                    waveforms.append(self._load_audio_file(item))
+
+        return waveforms
+
+    def _convert_audio(self, audio: AudioData) -> torch.Tensor:
+        if isinstance(audio, np.ndarray):
+            audio = torch.from_numpy(audio)
+        return audio.squeeze().float()  # Ensure float32
+
+    def _load_audio_file(self, path: str) -> torch.Tensor:
+        waveform, sample_rate = torchaudio.load(path)
+        waveform = waveform.float()  # Ensure float32
+        if sample_rate != self.sampling_rate:
+            resampler = torchaudio.transforms.Resample(sample_rate, self.sampling_rate)
+            waveform = resampler(waveform)
+        return waveform.squeeze()
+
+    def get_audio_embeddings(
+        self,
+        audio: AudioBatch,
+        **kwargs: Any,
+    ) -> np.ndarray:
+        all_features = []
+        target_sampling_rate = 48000  # CLAP's expected sampling rate
+
+        if 
isinstance(audio, DataLoader): + # Process all batches + for batch in tqdm(audio, desc="Processing audio batches"): + batch_features = [] + # Process each item in the batch individually to avoid memory issues + for item in batch: + inputs = self.pipeline.feature_extractor( + [item["array"]], + sampling_rate=target_sampling_rate, + return_tensors="pt", + padding=True, + ) + inputs = {k: v.to(self.device) for k, v in inputs.items()} + + with torch.no_grad(): + audio_features = self.pipeline.model.get_audio_features( + **inputs + ) + audio_features = audio_features / audio_features.norm( + dim=-1, keepdim=True + ) + batch_features.append(audio_features.cpu().numpy()) + + all_features.extend(batch_features) + + return np.vstack(all_features) + else: + # Process single batch + batch_features = [] + for item in audio: + inputs = self.pipeline.feature_extractor( + [item["array"]], + sampling_rate=target_sampling_rate, + return_tensors="pt", + padding=True, + ) + inputs = {k: v.to(self.device) for k, v in inputs.items()} + + with torch.no_grad(): + audio_features = self.pipeline.model.get_audio_features(**inputs) + audio_features = audio_features / audio_features.norm( + dim=-1, keepdim=True + ) + batch_features.append(audio_features.cpu().numpy()) + + return np.vstack(batch_features) + + def get_text_embeddings( + self, + texts: list[str], + **kwargs: Any, + ) -> np.ndarray: + inputs = self.processor(text=texts, return_tensors="pt", padding=True) + inputs = {k: v.to(self.device) for k, v in inputs.items()} + + with torch.no_grad(): + text_features = self.model.get_text_features(**inputs) + # Normalize embeddings + text_features = text_features / text_features.norm(dim=-1, keepdim=True) + + return text_features.cpu().numpy() + + def encode( + self, + inputs: AudioBatch | list[str], + *, + task_name: str, + prompt_type: PromptType | None = None, + **kwargs: Any, + ) -> np.ndarray: + if isinstance(inputs[0], str): + return self.get_text_embeddings(inputs) + return self.get_audio_embeddings(inputs) + + +# Model metadata +clap_htsat_fused = ModelMeta( + loader=partial(ClapZeroShotWrapper, model_name="laion/clap-htsat-fused"), + name="laion/clap-htsat-fused", + languages=["en"], + revision="main", + release_date="2023-05-22", + modalities=["audio", "text"], + n_parameters=153_507_530, # Calculated using torch.numel(model.parameters()) + memory_usage_mb=586, # Calculated using model.calculate_memory_usage_mb() + max_tokens=float("inf"), + embed_dim=512, # The project_dim in config.json is 512 + license="MIT", + open_weights=True, + public_training_code="https://github.com/LAION-AI/CLAP", + public_training_data="LAION-Audio-630K", + framework=["PyTorch"], + reference="https://huggingface.co/laion/clap_htsat_fused", + similarity_fn_name="cosine", + use_instructions=False, + training_datasets={"LAION-Audio-630K": ["https://laion.ai/blog/laion-audio-630k/"]}, +) diff --git a/mteb/models/overview.py b/mteb/models/overview.py index 091a3c4ebb..b5fa709530 100644 --- a/mteb/models/overview.py +++ b/mteb/models/overview.py @@ -20,6 +20,7 @@ blip_models, bm25, cde_models, + clap_models, clip_models, cohere_models, cohere_v, @@ -138,6 +139,7 @@ voyage_models, fa_models, wav2vec2_models, + clap_models, ] MODEL_REGISTRY = {} diff --git a/mteb/tasks/Audio/AudioZeroshotClassification/__init__.py b/mteb/tasks/Audio/AudioZeroshotClassification/__init__.py new file mode 100644 index 0000000000..20efe80442 --- /dev/null +++ b/mteb/tasks/Audio/AudioZeroshotClassification/__init__.py @@ -0,0 +1,3 @@ +from __future__ import 
annotations + +from .eng.ESC50 import * diff --git a/mteb/tasks/Audio/AudioZeroshotClassification/eng/ESC50.py b/mteb/tasks/Audio/AudioZeroshotClassification/eng/ESC50.py new file mode 100644 index 0000000000..02e363db6c --- /dev/null +++ b/mteb/tasks/Audio/AudioZeroshotClassification/eng/ESC50.py @@ -0,0 +1,107 @@ +from __future__ import annotations + +from mteb.abstasks.Audio.AbsTaskAudioZeroshotClassification import ( + AbsTaskAudioZeroshotClassification, +) +from mteb.abstasks.TaskMetadata import TaskMetadata + + +class ESC50ZeroshotClassification(AbsTaskAudioZeroshotClassification): + metadata = TaskMetadata( + name="ESC50_Zeroshot", + description="Environmental Sound Classification Dataset.", + reference="https://huggingface.co/datasets/ashraq/esc50", + dataset={ + "path": "ashraq/esc50", + "revision": "e3e2a63ffff66b9a9735524551e3818e96af03ee", + }, + type="AudioZeroshotClassification", + category="a2a", + eval_splits=["train"], + eval_langs=["eng-Latn"], + main_score="accuracy", + date=("2023-01-07", "2023-01-07"), + domains=[ + "Spoken" + ], # Replace with appropriate domain from allowed list?? No appropriate domain name is available + task_subtypes=["Environment Sound Classification"], + license="cc-by-nc-sa-3.0", # Replace with appropriate license from allowed list + annotations_creators="human-annotated", + dialect=[], + modalities=["audio"], + sample_creation="found", + bibtex_citation="""@inproceedings{piczak2015dataset, + title = {{ESC}: {Dataset} for {Environmental Sound Classification}}, + author = {Piczak, Karol J.}, + booktitle = {Proceedings of the 23rd {Annual ACM Conference} on {Multimedia}}, + date = {2015-10-13}, + url = {http://dl.acm.org/citation.cfm?doid=2733373.2806390}, + doi = {10.1145/2733373.2806390}, + location = {{Brisbane, Australia}}, + isbn = {978-1-4503-3459-4}, + publisher = {{ACM Press}}, + pages = {1015--1018} + }""", + descriptive_stats={ + "n_samples": {"train": 2000}, # Need actual number + }, + ) + + audio_column_name: str = "audio" + label_column_name: str = "target" + samples_per_label: int = 50 + + def get_candidate_labels(self) -> list[str]: + """Return the text candidates for zeroshot classification""" + return [ + "This is a sound of dog", + "This is a sound of rooster", + "This is a sound of pig", + "This is a sound of cow", + "This is a sound of frog", + "This is a sound of cat", + "This is a sound of hen", + "This is a sound of insects", + "This is a sound of sheep", + "This is a sound of crow", + "This is a sound of rain", + "This is a sound of sea_waves", + "This is a sound of crackling_fire", + "This is a sound of crickets", + "This is a sound of chirping_birds", + "This is a sound of water_drops", + "This is a sound of wind", + "This is a sound of pouring_water", + "This is a sound of toilet_flush", + "This is a sound of thunderstorm", + "This is a sound of crying_baby", + "This is a sound of sneezing", + "This is a sound of clapping", + "This is a sound of breathing", + "This is a sound of coughing", + "This is a sound of footsteps", + "This is a sound of laughing", + "This is a sound of brushing_teeth", + "This is a sound of snoring", + "This is a sound of drinking_sipping", + "This is a sound of door_wood_knock", + "This is a sound of mouse_click", + "This is a sound of keyboard_typing", + "This is a sound of door_wood_creaks", + "This is a sound of can_opening", + "This is a sound of washing_machine", + "This is a sound of vacuum_cleaner", + "This is a sound of clock_alarm", + "This is a sound of clock_tick", + "This is a sound 
of glass_breaking", + "This is a sound of helicopter", + "This is a sound of chainsaw", + "This is a sound of siren", + "This is a sound of car_horn", + "This is a sound of engine", + "This is a sound of train", + "This is a sound of church_bells", + "This is a sound of airplane", + "This is a sound of fireworks", + "This is a sound of hand_saw", + ] diff --git a/mteb/tasks/Audio/__init__.py b/mteb/tasks/Audio/__init__.py index 9f32ec6a10..8194faaf1c 100644 --- a/mteb/tasks/Audio/__init__.py +++ b/mteb/tasks/Audio/__init__.py @@ -1,4 +1,5 @@ from __future__ import annotations from .AudioMultilabelClassification import * +from .AudioZeroshotClassification import * from .Clustering import * diff --git a/tests/test_benchmark/mock_models.py b/tests/test_benchmark/mock_models.py index b65b3e6559..c6c40759b9 100644 --- a/tests/test_benchmark/mock_models.py +++ b/tests/test_benchmark/mock_models.py @@ -90,7 +90,7 @@ def get_text_embeddings( texts, **kwargs, ): - return torch.randn(len(texts), self.embedding_dim) + return np.random.rand(len(texts), self.embedding_dim) def calculate_probs( self, text_embeddings: np.ndarray, audio_embeddings: np.ndarray diff --git a/tests/test_benchmark/mock_tasks.py b/tests/test_benchmark/mock_tasks.py index 1ee4722015..15847e9aa4 100644 --- a/tests/test_benchmark/mock_tasks.py +++ b/tests/test_benchmark/mock_tasks.py @@ -24,6 +24,9 @@ from mteb.abstasks.Audio.AbsTaskAudioMultilabelClassification import ( AbsTaskAudioMultilabelClassification, ) +from mteb.abstasks.Audio.AbsTaskAudioZeroshotClassification import ( + AbsTaskAudioZeroshotClassification, +) from mteb.abstasks.Image.AbsTaskAny2AnyMultiChoice import AbsTaskAny2AnyMultiChoice from mteb.abstasks.Image.AbsTaskAny2AnyRetrieval import AbsTaskAny2AnyRetrieval from mteb.abstasks.Image.AbsTaskAny2TextMultipleChoice import ( @@ -582,6 +585,58 @@ def load_data(self, **kwargs): self.data_loaded = True +class MockAudioZeroshotClassificationTask(AbsTaskAudioZeroshotClassification): + audio_column_name: str = "audio" + label_column_name: str = "label" + + expected_stats = { + "test": { + "num_samples": 2, + "total_duration": 2.0, # 2 samples * 1s each + "min_duration": 1.0, + "avg_duration": 1.0, + "max_duration": 1.0, + "sample_rate": 16000, + "unique_labels": 2, + "labels": {"0": {"count": 1}, "1": {"count": 1}}, + } + } + + metadata = TaskMetadata( + type="AudioZeroshotClassification", + name="MockAudioZeroshotClassification", + main_score="accuracy", + **general_args, + ) + + def load_data(self, **kwargs): + # Create mock audio data as numpy arrays + mock_audio = [ + { + "array": np.random.rand(16000).astype(np.float32), # 1s audio + "sampling_rate": 16000, + } + for _ in range(2) + ] + labels = np.array([0, 1]) # Convert labels to numpy array + + self.dataset = DatasetDict( + { + "test": Dataset.from_dict( + { + "audio": mock_audio, + "label": labels, + } + ), + } + ) + self.data_loaded = True + + def get_candidate_labels(self) -> list[str]: + """Return the text candidates for zeroshot classification""" + return ["This is sound type 0", "This is sound type 1"] + + class MockMultilingualClusteringTask(AbsTaskClustering, MultilingualTask): expected_stats = { "test": { diff --git a/tests/test_benchmark/task_grid.py b/tests/test_benchmark/task_grid.py index a408fd3f06..127f27bff0 100644 --- a/tests/test_benchmark/task_grid.py +++ b/tests/test_benchmark/task_grid.py @@ -16,6 +16,7 @@ MockAny2AnyRetrievalT2ITask, MockAudioClusteringTask, MockAudioMultilabelClassificationTask, + MockAudioZeroshotClassificationTask, 
MockBitextMiningTask, MockClassificationTask, MockClusteringFastTask, @@ -140,6 +141,7 @@ MOCK_MAEB_TASK_GRID = [ MockAudioClusteringTask(), MockAudioMultilabelClassificationTask(), + MockAudioZeroshotClassificationTask(), ]
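
A minimal usage sketch (not part of the diff) of how the pieces added here are expected to fit together. It assumes mteb's existing registry API (mteb.get_tasks, mteb.get_model, mteb.MTEB); exact entry points and run() arguments may differ slightly depending on the mteb version this diff targets.

    import mteb

    # "ESC50_Zeroshot" and "laion/clap-htsat-fused" are the task and model names registered in this diff.
    tasks = mteb.get_tasks(tasks=["ESC50_Zeroshot"])
    model = mteb.get_model("laion/clap-htsat-fused")  # resolves to ClapZeroShotWrapper via ModelMeta.loader

    # The evaluator embeds the candidate label prompts with get_text_embeddings,
    # embeds the audio clips with get_audio_embeddings, and scores by cosine similarity.
    evaluation = mteb.MTEB(tasks=tasks)
    results = evaluation.run(model, encode_kwargs={"batch_size": 32})
    print(results)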