diff --git a/mteb/abstasks/Audio/AbsTaskAny2AnyRetrieval.py b/mteb/abstasks/Audio/AbsTaskAny2AnyRetrieval.py
index 5fd2e208a0..739dffccf6 100644
--- a/mteb/abstasks/Audio/AbsTaskAny2AnyRetrieval.py
+++ b/mteb/abstasks/Audio/AbsTaskAny2AnyRetrieval.py
@@ -8,10 +8,9 @@
 from time import time
 from typing import Any
 
+import torch
 import tqdm
 from datasets import Features, Value, load_dataset
-import torch
-import torchaudio
 
 from ...evaluation.evaluators import Any2AnyRetrievalEvaluator
 from ..AbsTask import AbsTask, ScoresDict
@@ -43,9 +42,15 @@ def __init__(
         if prefix:
             query_file = prefix + "-" + query_file
             qrels_folder = prefix + "-" + qrels_folder
-        self.corpus_file = os.path.join(data_folder, corpus_file) if data_folder else corpus_file
-        self.query_file = os.path.join(data_folder, query_file) if data_folder else query_file
-        self.qrels_folder = os.path.join(data_folder, qrels_folder) if data_folder else None
+        self.corpus_file = (
+            os.path.join(data_folder, corpus_file) if data_folder else corpus_file
+        )
+        self.query_file = (
+            os.path.join(data_folder, query_file) if data_folder else query_file
+        )
+        self.qrels_folder = (
+            os.path.join(data_folder, qrels_folder) if data_folder else None
+        )
         self.qrels_file = qrels_file
         self.streaming = streaming
         self.keep_in_memory = keep_in_memory
@@ -57,7 +62,9 @@ def check(fIn: str, ext: str):
         if not fIn.endswith(ext):
             raise ValueError(f"File {fIn} must have extension {ext}")
 
-    def load(self, split="test") -> tuple[
+    def load(
+        self, split="test"
+    ) -> tuple[
         dict[str, dict[str, str | torch.Tensor]],
         dict[str, dict[str, str | torch.Tensor]],
         dict[str, dict[str, int]],
@@ -71,7 +78,9 @@ def load(self, split="test") -> tuple[
         if not len(self.corpus):
             logger.info("Loading Corpus...")
             self._load_corpus()
-            logger.info("Loaded %d Documents for %s split.", len(self.corpus), split.upper())
+            logger.info(
+                "Loaded %d Documents for %s split.", len(self.corpus), split.upper()
+            )
             logger.info("Doc Example: %s", self.corpus[0])
 
         if not len(self.queries):
@@ -80,8 +89,10 @@ def load(self, split="test") -> tuple[
 
         self._load_qrels(split)
         qrels_dict = defaultdict(dict)
+
         def qrels_dict_init(row):
             qrels_dict[row["query-id"]][row["corpus-id"]] = int(row["score"])
+
         self.qrels.map(qrels_dict_init)
         self.qrels = qrels_dict
         self.queries = self.queries.filter(lambda x: x["id"] in self.qrels)
@@ -150,18 +161,21 @@ def _load_qrels(self, split):
         )
         if "Q0" in qrels_ds.column_names:
             qrels_ds = qrels_ds.remove_columns("Q0")
-        features = Features({
-            "query-id": Value("string"),
-            "corpus-id": Value("string"),
-            "score": Value("float"),
-        })
-        qrels_ds = qrels_ds.select_columns(["query-id", "corpus-id", "score"]).cast(features)
+        features = Features(
+            {
+                "query-id": Value("string"),
+                "corpus-id": Value("string"),
+                "score": Value("float"),
+            }
+        )
+        qrels_ds = qrels_ds.select_columns(["query-id", "corpus-id", "score"]).cast(
+            features
+        )
         self.qrels = qrels_ds
 
 
 class AbsTaskAny2AnyRetrieval(AbsTask):
-    """
-    Abstract class for audio-text retrieval experiments.
+    """Abstract class for audio-text retrieval experiments.
 
     Child-classes must implement:
         - self.corpus: dict[str, dict[str, str]]
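Not part of the patch, just an illustration of the data layout the docstring above asks child classes to populate. Only the nesting (split → id → record, and split → query-id → {corpus-id: score}) is taken from the type hints in `load()` and the per-split assignment in `load_data()`; the field names inside each record and the identifiers are assumptions.

```python
import torch

split = "test"

# corpus / queries: split -> id -> record (string fields and/or waveform tensors);
# the "modality"/"audio"/"text" keys are illustrative placeholders.
corpus = {
    split: {
        "doc-0": {"modality": "audio", "audio": torch.zeros(16000)},  # 1 s of silence
        "doc-1": {"modality": "text", "text": "a dog barking"},
    }
}
queries = {
    split: {"q-0": {"modality": "text", "text": "which clip contains barking?"}}
}

# relevant_docs (qrels): split -> query-id -> {corpus-id: relevance score}
relevant_docs = {
    split: {"q-0": {"doc-0": 1}},
}
```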
@@ -190,7 +204,9 @@ def load_data(self, **kwargs):
                 keep_in_memory=False,
             ).load(split=split)
             self.corpus[split], self.queries[split], self.relevant_docs[split] = (
-                corpus, queries, qrels
+                corpus,
+                queries,
+                qrels,
             )
         self.data_loaded = True
 
@@ -249,10 +265,16 @@ def _evaluate_subset(
         if top_k is not None:
             for qid in list(results.keys()):
                 doc_ids = set(
-                    sorted(results[qid], key=lambda x: results[qid][x], reverse=True)[:top_k]
+                    sorted(
+                        results[qid], key=lambda x: results[qid][x], reverse=True
+                    )[:top_k]
                 )
-                results[qid] = {k: v for k, v in results[qid].items() if k in doc_ids}
-        predictions_path = output_folder / f"{self.metadata.name}_{hf_subset}_predictions.json"
+                results[qid] = {
+                    k: v for k, v in results[qid].items() if k in doc_ids
+                }
+        predictions_path = (
+            output_folder / f"{self.metadata.name}_{hf_subset}_predictions.json"
+        )
         with open(predictions_path, "w") as f:
             json.dump(results, f)
 
@@ -295,14 +317,20 @@ def _evaluate_subset(
                 results[qid] = dict(sorted_docs)
             for qid, retrieved_docs in results.items():
                 expected_docs = relevant_docs[qid]
-                false_positives = [doc for doc in retrieved_docs if doc not in expected_docs]
-                false_negatives = [doc for doc in expected_docs if doc not in retrieved_docs]
+                false_positives = [
+                    doc for doc in retrieved_docs if doc not in expected_docs
+                ]
+                false_negatives = [
+                    doc for doc in expected_docs if doc not in retrieved_docs
+                ]
                 if false_positives or false_negatives:
                     errors[qid] = {
                         "false_positives": false_positives,
                         "false_negatives": false_negatives,
                     }
-            errors_path = output_folder / f"{self.metadata.name}_{hf_subset}_errors.json"
+            errors_path = (
+                output_folder / f"{self.metadata.name}_{hf_subset}_errors.json"
+            )
             with open(errors_path, "w") as f:
                 json.dump(errors, f)
 
@@ -319,13 +347,17 @@ def _calculate_metrics_from_split(
     def calculate_metadata_metrics(self) -> None:
         self.load_data()
         all_details = {}
-        pbar_split = tqdm.tqdm(self.metadata_dict["eval_splits"], desc="Processing Splits...")
+        pbar_split = tqdm.tqdm(
+            self.metadata_dict["eval_splits"], desc="Processing Splits..."
+        )
         for split in pbar_split:
             pbar_split.set_postfix_str(f"Split: {split}")
             logger.info(f"Processing metadata for split {split}")
             all_details[split] = {}
             if self.is_multilingual:
-                pbar_lang = tqdm.tqdm(self.relevant_docs.keys(), desc="Processing Languages...")
+                pbar_lang = tqdm.tqdm(
+                    self.relevant_docs.keys(), desc="Processing Languages..."
+                )
                 for lang in pbar_lang:
                     pbar_lang.set_postfix_str(f"Language: {lang}")
                     logger.info(f"Processing metadata for language {lang}")
@@ -358,7 +390,9 @@ def process_language(relevant_docs, queries, corpus, lang=None):
     logger.info(f"Average query length{language_description} is {query_len}")
     logger.info(f"Number of documents{language_description} is {num_documents}")
     logger.info(f"Number of queries{language_description} is {num_queries}")
-    logger.info(f"Average relevant docs per query{language_description} is {qrels_per_doc}")
+    logger.info(
+        f"Average relevant docs per query{language_description} is {qrels_per_doc}"
+    )
     return {
         "average_document_length": doc_len,
         "average_query_length": query_len,
diff --git a/mteb/abstasks/TaskMetadata.py b/mteb/abstasks/TaskMetadata.py
index 24b3c9fa23..b425985041 100644
--- a/mteb/abstasks/TaskMetadata.py
+++ b/mteb/abstasks/TaskMetadata.py
@@ -59,6 +59,7 @@
     "Activity recognition",
     "Tumor detection",
     "Duplicate Detection",
+    "Gender Clustering",
 ]
 
 TASK_DOMAIN = Literal[
@@ -101,6 +102,7 @@
     "multiple",
 ]
 TASK_TYPE = Literal[
+    "AudioClustering",
    "BitextMining",
    "Classification",
    "MultilabelClassification",
@@ -137,6 +139,7 @@
     "i2it",  # image-to-image+text
     "t2it",  # text-to-image+text
     "it2it",  # image+text-to-image+text
+    "a2a",  # audio-to-audio
 ]
 
 ANNOTATOR_TYPE = Literal[
diff --git a/mteb/evaluation/evaluators/Audio/Any2AnyRetrievalEvaluator.py b/mteb/evaluation/evaluators/Audio/Any2AnyRetrievalEvaluator.py
index bd03af28b4..cf9e55f27c 100644
--- a/mteb/evaluation/evaluators/Audio/Any2AnyRetrievalEvaluator.py
+++ b/mteb/evaluation/evaluators/Audio/Any2AnyRetrievalEvaluator.py
@@ -12,8 +12,8 @@
 import numpy as np
 import pytrec_eval
 import torch
-from datasets import Dataset
 import torchaudio
+from datasets import Dataset
 from torch.utils.data import DataLoader
 
 from mteb.encoder_interface import Encoder, PromptType
@@ -33,8 +33,10 @@
 
 logger = logging.getLogger(__name__)
 
+
 # A default transform for audio; replace with a more meaningful transform as needed.
-DEFAULT_AUDIO_TRANSFORM = lambda x: x
+def DEFAULT_AUDIO_TRANSFORM(x):
+    return x
 
 
 class AudioDataset(torch.utils.data.Dataset):
@@ -342,7 +344,9 @@ def evaluate(
                     if qid == pid:
                         results[qid].pop(pid)
         else:
-            logger.debug("Not ignoring identical query and document ids. Set ignore_identical_ids=True to ignore.")
+            logger.debug(
+                "Not ignoring identical query and document ids. Set ignore_identical_ids=True to ignore."
+            )
 
         all_ndcgs, all_aps, all_recalls, all_precisions, all_cv_recalls = (
             {},
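The `DEFAULT_AUDIO_TRANSFORM` hook introduced above is an identity function, and the comment beside it invites something more meaningful. A hedged sketch of what such a transform could look like, using the `torchaudio` import already present in this file; the 48 kHz source rate is an assumption, and how `AudioDataset` applies the transform is not shown in this hunk:

```python
import torch
import torchaudio


def resample_to_16khz(waveform: torch.Tensor, orig_freq: int = 48_000) -> torch.Tensor:
    """Resample a (channels, time) waveform to 16 kHz, the rate wav2vec2-style models expect."""
    resampler = torchaudio.transforms.Resample(orig_freq=orig_freq, new_freq=16_000)
    return resampler(waveform)


clip = torch.randn(1, 48_000)         # one second of synthetic audio at 48 kHz
print(resample_to_16khz(clip).shape)  # torch.Size([1, 16000])
```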
@@ -404,7 +408,9 @@ def evaluate(
             _map[f"MAP@{k}"] = round(sum(_map[f"MAP@{k}"]) / len(scores), 5)
             recall[f"Recall@{k}"] = round(sum(recall[f"Recall@{k}"]) / len(scores), 5)
             precision[f"P@{k}"] = round(sum(precision[f"P@{k}"]) / len(scores), 5)
-            cv_recall[f"CV_Recall@{k}"] = round(sum(cv_recall[f"CV_Recall@{k}"]) / len(scores), 5)
+            cv_recall[f"CV_Recall@{k}"] = round(
+                sum(cv_recall[f"CV_Recall@{k}"]) / len(scores), 5
+            )
 
         naucs = Any2AnyRetrievalEvaluator.evaluate_abstention(
             results,
@@ -427,7 +433,13 @@ def evaluate_custom(
             metric_scores = recall_cap(qrels, results, k_values, output_type)
         elif metric.lower() in ["hole", "hole@k"]:
             metric_scores = hole(qrels, results, k_values, output_type)
-        elif metric.lower() in ["acc", "top_k_acc", "accuracy", "accuracy@k", "top_k_accuracy"]:
+        elif metric.lower() in [
+            "acc",
+            "top_k_acc",
+            "accuracy",
+            "accuracy@k",
+            "top_k_accuracy",
+        ]:
             metric_scores = top_k_accuracy(qrels, results, k_values, output_type)
         naucs = Any2AnyRetrievalEvaluator.evaluate_abstention(results, metric_scores)
         metric_scores_avg = {k: sum(v) / len(v) for k, v in metric_scores.items()}
@@ -460,5 +472,9 @@ def calculate_cv_style_recall(
         cv_recalls = {}
         for query_id, relevant_docs in qrels.items():
             retrieved_docs = list(results.get(query_id, {}).keys())[:k]
-            cv_recalls[query_id] = 1.0 if any(doc_id in relevant_docs for doc_id in retrieved_docs) else 0.0
+            cv_recalls[query_id] = (
+                1.0
+                if any(doc_id in relevant_docs for doc_id in retrieved_docs)
+                else 0.0
+            )
         return cv_recalls
diff --git a/mteb/models/overview.py b/mteb/models/overview.py
index ef41f79088..05278fa664 100644
--- a/mteb/models/overview.py
+++ b/mteb/models/overview.py
@@ -71,6 +71,7 @@
     vlm2vec_models,
     voyage_models,
     voyage_v,
+    wav2vec_models,
 )
 
 logger = logging.getLogger(__name__)
@@ -136,6 +137,7 @@
     uae_models,
     voyage_models,
     fa_models,
+    wav2vec_models,
 ]
 
 MODEL_REGISTRY = {}
diff --git a/mteb/models/wav2vec_models.py b/mteb/models/wav2vec_models.py
new file mode 100644
index 0000000000..10ce3cc1a9
--- /dev/null
+++ b/mteb/models/wav2vec_models.py
@@ -0,0 +1,215 @@
+from __future__ import annotations
+
+from functools import partial
+
+import numpy as np
+import torch
+from datasets import Audio
+from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2Model
+
+from mteb.encoder_interface import AudioEncoder, PromptType
+from mteb.model_meta import ModelMeta
+
+
+class Wav2vec2Wrapper(AudioEncoder):
+    def __init__(
+        self,
+        device: str | None = None,
+        model_name="facebook/wav2vec2-base",
+        model_revision=None,
+        **kwargs,
+    ):
+        self.device = device
+        self.model = Wav2Vec2Model.from_pretrained(model_name, revision=model_revision)
+        self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
+            model_name, revision=model_revision
+        )
+        self.embed_dim = self.model.config.hidden_size
+
+        if device:
+            self.model = self.model.to(device)
+
+    def get_audio_embeddings(
+        self, audio_files: list[Audio] | Audio, batch_size: int = 32, **kwargs
+    ) -> np.ndarray:
+        if not isinstance(audio_files, list):
+            audio_files = [audio_files]
+
+        all_embeddings = []
+
+        for i in range(0, len(audio_files), batch_size):
+            batch = audio_files[i : i + batch_size]
+
+            audio_data = [file["array"] for file in batch]
+            sampling_rates = [file["sampling_rate"] for file in batch]
+
+            # Preprocess batch
+            inputs = self.feature_extractor(
+                audio_data,
+                sampling_rate=sampling_rates[0],
+                padding=True,
+                return_tensors="pt",
+            )
+
+            if self.device:
+                inputs = {k: v.to(self.device) for k, v in inputs.items()}
+
+            # Get embeddings
+            with torch.no_grad():
+                outputs = self.model(
+                    input_values=inputs["input_values"],
+                    output_hidden_states=True,
+                    return_dict=True,
+                )
+
+                hidden_states = outputs.hidden_states[-1]
+                batch_embeddings = hidden_states.mean(dim=1).cpu().numpy()
+                all_embeddings.append(batch_embeddings)
+
+        return np.vstack(all_embeddings)
+
+    def encode(
+        self,
+        audio_files: list[Audio],
+        *,
+        task_name: str,
+        prompt_type: PromptType | None = None,
+        batch_size: int = 32,
+        **kwargs,
+    ) -> np.ndarray:
+        return self.get_audio_embeddings(audio_files, batch_size=batch_size, **kwargs)
+
+
+wav2vec2_base = ModelMeta(
+    loader=partial(
+        Wav2vec2Wrapper,
+        model_name="facebook/wav2vec2-base",
+        model_revision="0b5b8e868dd84f03fd87d01f9c4ff0f080fecfe8",
+    ),
+    name="facebook/wav2vec2-base",
+    languages=["en"],
+    open_weights=True,
+    revision="0b5b8e868dd84f03fd87d01f9c4ff0f080fecfe8",
+    release_date="2020-10-26",
+    max_tokens=float("inf"),
+    n_parameters=95_000_000,
+    memory_usage_mb=362,
+    embed_dim=768,
+    license="Apache-2.0",
+    reference="https://huggingface.co/facebook/wav2vec2-base",
+    similarity_fn_name="cosine",
+    framework=["PyTorch"],
+    use_instructions=False,
+    public_training_code=None,
+    public_training_data=None,
+    training_datasets=None,
+    modalities=["audio"],
+)
+
+
+wav2vec2_base_960h = ModelMeta(
+    loader=partial(
+        Wav2vec2Wrapper,
+        model_name="facebook/wav2vec2-base-960h",
+        model_revision="22aad52d435eb6dbaf354bdad9b0da84ce7d6156",
+    ),
+    name="facebook/wav2vec2-base-960h",
+    languages=["en"],
+    open_weights=True,
+    revision="22aad52d435eb6dbaf354bdad9b0da84ce7d6156",
+    release_date="2020-10-26",
+    max_tokens=float("inf"),
+    n_parameters=95_000_000,
+    memory_usage_mb=360,
+    embed_dim=768,
+    license="Apache-2.0",
+    reference="https://huggingface.co/facebook/wav2vec2-base-960h",
+    similarity_fn_name="cosine",
+    framework=["PyTorch"],
+    use_instructions=False,
+    public_training_code=None,
+    public_training_data=None,
+    training_datasets=None,
+    modalities=["audio"],
+)
+
+
+wav2vec2_large = ModelMeta(
+    loader=partial(
+        Wav2vec2Wrapper,
+        model_name="facebook/wav2vec2-large",
+        model_revision="312b2410566b698c7a649068d413b2067848bd75",
+    ),
+    name="facebook/wav2vec2-large",
+    languages=["en"],
+    open_weights=True,
+    revision="312b2410566b698c7a649068d413b2067848bd75",
+    release_date="2020-10-26",
+    max_tokens=float("inf"),
+    n_parameters=317_000_000,
+    memory_usage_mb=1_209,
+    embed_dim=1_024,
+    license="Apache-2.0",
+    reference="https://huggingface.co/facebook/wav2vec2-large",
+    similarity_fn_name="cosine",
+    framework=["PyTorch"],
+    use_instructions=False,
+    public_training_code=None,
+    public_training_data=None,
+    training_datasets=None,
+    modalities=["audio"],
+)
+
+
+wav2vec2_large_xlsr_53 = ModelMeta(
+    loader=partial(
+        Wav2vec2Wrapper,
+        model_name="facebook/wav2vec2-large-xlsr-53",
+        model_revision="c3f9d884181a224a6ac87bf8885c84d1cff3384f",
+    ),
+    name="facebook/wav2vec2-large-xlsr-53",
+    languages=["en"],
+    open_weights=True,
+    revision="c3f9d884181a224a6ac87bf8885c84d1cff3384f",
+    release_date="2020-10-26",
+    max_tokens=float("inf"),
+    n_parameters=317_000_000,
+    memory_usage_mb=1_209,
+    embed_dim=1_024,
+    license="Apache-2.0",
+    reference="https://huggingface.co/facebook/wav2vec2-large-xlsr-53",
+    similarity_fn_name="cosine",
+    framework=["PyTorch"],
+    use_instructions=False,
+    public_training_code=None,
+    public_training_data=None,
+    training_datasets=None,
+    modalities=["audio"],
+)
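Not applied by this diff, just a sketch of how the wrapper and the `ModelMeta` entries above are meant to be used. Direct instantiation relies only on code visible in this file; reaching the same wrapper through `mteb.get_model("facebook/wav2vec2-base")` assumes the registry import added in `overview.py` behaves like the existing model families.

```python
import numpy as np

from mteb.models.wav2vec_models import Wav2vec2Wrapper

model = Wav2vec2Wrapper(device="cpu", model_name="facebook/wav2vec2-base")

# Each clip mirrors the datasets.Audio dict layout the wrapper indexes into.
clips = [
    {"array": np.random.rand(16000), "sampling_rate": 16000},  # 1 s at 16 kHz
    {"array": np.random.rand(32000), "sampling_rate": 16000},  # 2 s at 16 kHz
]

embeddings = model.get_audio_embeddings(clips, batch_size=2)
print(embeddings.shape)  # (2, 768): mean-pooled last hidden state of wav2vec2-base
```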
+
+
+wav2vec2_lv_60_espeak_cv_ft = ModelMeta(
+    loader=partial(
+        Wav2vec2Wrapper,
+        model_name="facebook/wav2vec2-lv-60-espeak-cv-ft",
+        model_revision="ae45363bf3413b374fecd9dc8bc1df0e24c3b7f4",
+    ),
+    name="facebook/wav2vec2-lv-60-espeak-cv-ft",
+    languages=["en"],
+    open_weights=True,
+    revision="ae45363bf3413b374fecd9dc8bc1df0e24c3b7f4",
+    release_date="2020-10-26",
+    max_tokens=float("inf"),
+    n_parameters=317_000_000,
+    memory_usage_mb=1_209,
+    embed_dim=1_024,
+    license="Apache-2.0",
+    reference="https://huggingface.co/facebook/wav2vec2-lv-60-espeak-cv-ft",
+    similarity_fn_name="cosine",
+    framework=["PyTorch"],
+    use_instructions=False,
+    public_training_code=None,
+    public_training_data=None,
+    training_datasets=None,
+    modalities=["audio"],
+)
diff --git a/mteb/tasks/Audio/Clustering/__init__.py b/mteb/tasks/Audio/Clustering/__init__.py
new file mode 100644
index 0000000000..58bc6a22d8
--- /dev/null
+++ b/mteb/tasks/Audio/Clustering/__init__.py
@@ -0,0 +1,3 @@
+from __future__ import annotations
+
+from .eng.VoiceGender import *
diff --git a/mteb/tasks/Audio/Clustering/eng/VoiceGender.py b/mteb/tasks/Audio/Clustering/eng/VoiceGender.py
new file mode 100644
index 0000000000..b24cb3259a
--- /dev/null
+++ b/mteb/tasks/Audio/Clustering/eng/VoiceGender.py
@@ -0,0 +1,36 @@
+from __future__ import annotations
+
+from mteb.abstasks.Audio.AbsTaskAudioClustering import AbsTaskAudioClustering
+from mteb.abstasks.TaskMetadata import TaskMetadata
+
+
+class VoiceGenderClustering(AbsTaskAudioClustering):
+    label_column_name: str = "label"
+    metadata = TaskMetadata(
+        name="VoiceGenderClustering",
+        description="Clustering audio recordings based on gender (male vs female).",
+        reference="https://huggingface.co/datasets/mmn3690/voice-gender-clustering",
+        dataset={
+            "path": "mmn3690/voice-gender-clustering",
+            "revision": "1b202ea7bcd0abd5283e628248803e1569257c80",
+        },
+        type="AudioClustering",
+        category="a2a",
+        eval_splits=["train"],
+        eval_langs=["eng-Latn"],
+        main_score="clustering_accuracy",
+        date=("2024-01-01", "2024-12-31"),
+        domains=["Spoken"],
+        task_subtypes=["Gender Clustering"],
+        license="not specified",
+        annotations_creators="derived",
+        dialect=[],
+        modalities=["audio"],
+        sample_creation="found",
+        bibtex_citation="""@InProceedings{Chung18b,
+            author = "Chung, J.~S. and Nagrani, A. and Zisserman, A.",
and Zisserman, A.", + title = "VoxCeleb2: Deep Speaker Recognition", + booktitle = "INTERSPEECH", + year = "2018 + }""", + ) diff --git a/mteb/tasks/Audio/Clustering/eng/__init__.py b/mteb/tasks/Audio/Clustering/eng/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/mteb/tasks/Audio/__init__.py b/mteb/tasks/Audio/__init__.py new file mode 100644 index 0000000000..9777bd4544 --- /dev/null +++ b/mteb/tasks/Audio/__init__.py @@ -0,0 +1,3 @@ +from __future__ import annotations + +from .Clustering import * diff --git a/mteb/tasks/__init__.py b/mteb/tasks/__init__.py index e00f091174..25a8f78503 100644 --- a/mteb/tasks/__init__.py +++ b/mteb/tasks/__init__.py @@ -1,6 +1,7 @@ from __future__ import annotations from .aggregated_tasks import * +from .Audio.Clustering import * from .BitextMining import * from .Classification import * from .Clustering import * diff --git a/pyproject.toml b/pyproject.toml index 1086a34e73..21bd9ee4d0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,6 +41,7 @@ dependencies = [ "eval_type_backport>=0.0.0", "polars>=0.20.22", "torchvision>0.0.0", + "torchaudio>=2.6.0" ] diff --git a/tests/test_benchmark/mock_models.py b/tests/test_benchmark/mock_models.py index dece34e703..eb1764f5e0 100644 --- a/tests/test_benchmark/mock_models.py +++ b/tests/test_benchmark/mock_models.py @@ -72,28 +72,24 @@ def __init__(self): def get_audio_embeddings( self, - audio, # list + audio, # list **kwargs, ) -> np.ndarray: - return np.random.rand(len(audio), self.embedding_dim) def get_text_embeddings( self, - texts, # list + texts, # list **kwargs, - ) -> np.ndarray: + ) -> np.ndarray: pass def calculate_probs( - self, - text_embeddings: np.ndarray, - audio_embeddings: np.ndarray + self, text_embeddings: np.ndarray, audio_embeddings: np.ndarray ) -> np.ndarray: pass - class MockSentenceTransformer(SentenceTransformer): """A mock implementation of the SentenceTransformer intended to implement just the encode, method using the same arguments.""" diff --git a/tests/test_benchmark/mock_tasks.py b/tests/test_benchmark/mock_tasks.py index 8863c8c4d0..f2cc0aa563 100644 --- a/tests/test_benchmark/mock_tasks.py +++ b/tests/test_benchmark/mock_tasks.py @@ -10,7 +10,6 @@ from mteb.abstasks.AbsTaskBitextMining import AbsTaskBitextMining from mteb.abstasks.AbsTaskClassification import AbsTaskClassification from mteb.abstasks.AbsTaskClustering import AbsTaskClustering -from mteb.abstasks.Audio.AbsTaskAudioClustering import AbsTaskAudioClustering from mteb.abstasks.AbsTaskClusteringFast import AbsTaskClusteringFast from mteb.abstasks.AbsTaskInstructionRetrieval import AbsTaskInstructionRetrieval from mteb.abstasks.AbsTaskMultilabelClassification import ( @@ -21,6 +20,7 @@ from mteb.abstasks.AbsTaskRetrieval import AbsTaskRetrieval from mteb.abstasks.AbsTaskSTS import AbsTaskSTS from mteb.abstasks.AbsTaskSummarization import AbsTaskSummarization +from mteb.abstasks.Audio.AbsTaskAudioClustering import AbsTaskAudioClustering from mteb.abstasks.Image.AbsTaskAny2AnyMultiChoice import AbsTaskAny2AnyMultiChoice from mteb.abstasks.Image.AbsTaskAny2AnyRetrieval import AbsTaskAny2AnyRetrieval from mteb.abstasks.Image.AbsTaskAny2TextMultipleChoice import ( @@ -478,17 +478,18 @@ def load_data(self, **kwargs): ) self.data_loaded = True + class MockAudioClusteringTask(AbsTaskAudioClustering): expected_stats = { "test": { - "num_samples": 3, + "num_samples": 3, "number_of_samples": 3, "min_audio_length": 16000, # sr = 16000 - "average_audio_length": 16000, # 1s - "max_audio_length": 16000, 
+            "average_audio_length": 16000,  # 1s
+            "max_audio_length": 16000,  # 1s
             "unique_audios": 3,
             "min_labels_per_audio": 1,
-            "average_labels_per_audio": 1.0, 
+            "average_labels_per_audio": 1.0,
             "max_labels_per_audio": 1,
             "unique_labels": 3,
             "labels": {"0": {"count": 1}, "1": {"count": 1}, "2": {"count": 1}},
@@ -499,26 +500,26 @@ class MockAudioClusteringTask(AbsTaskAudioClustering):
         type="Clustering",
         name="MockAudioClusteringTask",
         main_score="v_measure",
-        **general_args, 
+        **general_args,
     )
 
     def load_data(self, **kwargs):
         mock_audio = [
             {
                 "array": np.random.rand(16000),  # 1s
-                "sampling_rate": 16000
-            } for _ in range(3)
+                "sampling_rate": 16000,
+            }
+            for _ in range(3)
         ]
-
-
-        labels = [0, 1, 2] 
+
+        labels = [0, 1, 2]
         self.dataset = DatasetDict(
             {
                 "test": Dataset.from_dict(
                     {
-                        "audio": mock_audio, 
-                        "labels": labels, 
+                        "audio": mock_audio,
+                        "labels": labels,
                     }
                 ),
             }
diff --git a/tests/test_benchmark/task_grid.py b/tests/test_benchmark/task_grid.py
index 47c4953eb6..797352977e 100644
--- a/tests/test_benchmark/task_grid.py
+++ b/tests/test_benchmark/task_grid.py
@@ -14,11 +14,11 @@
 from .mock_tasks import (
     MockAny2AnyRetrievalI2TTask,
     MockAny2AnyRetrievalT2ITask,
+    MockAudioClusteringTask,
     MockBitextMiningTask,
     MockClassificationTask,
     MockClusteringFastTask,
     MockClusteringTask,
-    MockAudioClusteringTask,
     MockImageClassificationKNNPTTask,
     MockImageClassificationKNNTask,
     MockImageClassificationTask,
@@ -136,10 +136,6 @@
     MockMultilingualImageMultilabelClassificationTask(),
 ]
 
-MOCK_MAEB_TASK_GRID = [
-    MockAudioClusteringTask()
-]
-
 MOCK_MIEB_TASK_GRID_AS_STRING = [
     t.metadata.name if isinstance(t, AbsTask) else t for t in MOCK_MIEB_TASK_GRID
 ]
@@ -147,3 +143,13 @@
 MOCK_MIEB_TASK_REGISTRY = {
     task.metadata.name: type(task) for task in MOCK_MIEB_TASK_GRID
 }
+
+MOCK_MAEB_TASK_GRID = [MockAudioClusteringTask()]
+
+MOCK_MAEB_TASK_GRID_AS_STRING = [
+    t.metadata.name if isinstance(t, AbsTask) else t for t in MOCK_MAEB_TASK_GRID
+]
+
+MOCK_MAEB_TASK_REGISTRY = {
+    task.metadata.name: type(task) for task in MOCK_MAEB_TASK_GRID
+}
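A sketch (not in the diff) of how the new MAEB mock grid is expected to be consumed, mirroring how the MIEB grids above are used in parametrized tests. The mock audio encoder class name is hypothetical; the real signature appears in `test_maeb_datasets.py` below.

```python
import pytest

import mteb
from mteb.abstasks.AbsTask import AbsTask

# Hypothetical import: the audio mock's class name is not visible in this diff.
from tests.test_benchmark.mock_models import MockAudioEncoder
from tests.test_benchmark.task_grid import MOCK_MAEB_TASK_GRID


@pytest.mark.parametrize("task", MOCK_MAEB_TASK_GRID)
def test_mock_audio_clustering(task: AbsTask):
    evaluation = mteb.MTEB(tasks=[task])
    evaluation.run(
        MockAudioEncoder(), output_folder="tests/results", overwrite_results=True
    )
```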
diff --git a/tests/test_tasks/test_all_abstasks.py b/tests/test_tasks/test_all_abstasks.py
index 7a87914f0a..16ff888813 100644
--- a/tests/test_tasks/test_all_abstasks.py
+++ b/tests/test_tasks/test_all_abstasks.py
@@ -20,13 +20,18 @@
 from mteb.overview import TASKS_REGISTRY
 
 from ..test_benchmark.task_grid import (
+    MOCK_MAEB_TASK_GRID_AS_STRING,
     MOCK_MIEB_TASK_GRID_AS_STRING,
     MOCK_TASK_TEST_GRID_AS_STRING,
 )
 
 logging.basicConfig(level=logging.INFO)
 
-ALL_MOCK_TASKS = MOCK_TASK_TEST_GRID_AS_STRING + MOCK_MIEB_TASK_GRID_AS_STRING
+ALL_MOCK_TASKS = (
+    MOCK_TASK_TEST_GRID_AS_STRING
+    + MOCK_MIEB_TASK_GRID_AS_STRING
+    + MOCK_MAEB_TASK_GRID_AS_STRING
+)
 
 tasks = [t for t in MTEB().tasks_cls if t.metadata.name not in ALL_MOCK_TASKS]
@@ -101,6 +106,7 @@ def test_dataset_availability():
         for t in tasks
         if t.metadata.name not in MOCK_TASK_TEST_GRID_AS_STRING
         if t.metadata.name not in MOCK_MIEB_TASK_GRID_AS_STRING
+        if t.metadata.name not in MOCK_MAEB_TASK_GRID_AS_STRING
         and t.metadata.name != "AfriSentiLangClassification"
         # HOTFIX: Issue#1777. Remove this line when issue is resolved.
     ]
diff --git a/tests/test_tasks/test_maeb_datasets.py b/tests/test_tasks/test_maeb_datasets.py
index fdd4b2f76d..c1a1b21a8b 100644
--- a/tests/test_tasks/test_maeb_datasets.py
+++ b/tests/test_tasks/test_maeb_datasets.py
@@ -21,4 +21,4 @@ def test_benchmark_audio_encoder(task: str | AbsTask, model: mteb.Encoder):
     """Test that a task can be fetched and run"""
     eval = MTEB(tasks=[task])
 
-    eval.run(model, output_folder="tests/results", overwrite_results=True)
\ No newline at end of file
+    eval.run(model, output_folder="tests/results", overwrite_results=True)
diff --git a/tests/test_tasks/test_mieb_datasets.py b/tests/test_tasks/test_mieb_datasets.py
index 431440d9bd..26e60931ec 100644
--- a/tests/test_tasks/test_mieb_datasets.py
+++ b/tests/test_tasks/test_mieb_datasets.py
@@ -21,4 +21,4 @@ def test_benchmark_sentence_transformer(task: str | AbsTask, model: mteb.Encoder):
     """Test that a task can be fetched and run"""
     eval = MTEB(tasks=[task])
 
-    eval.run(model, output_folder="tests/results", overwrite_results=True)
\ No newline at end of file
+    eval.run(model, output_folder="tests/results", overwrite_results=True)
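An end-to-end smoke run tying the pieces of this PR together: the wav2vec2 wrapper registered in `overview.py` and the new `VoiceGenderClustering` task. A sketch, assuming this branch is installed and the dataset revision pinned in the task metadata is reachable on the Hugging Face Hub.

```python
import mteb

model = mteb.get_model("facebook/wav2vec2-base")
tasks = mteb.get_tasks(tasks=["VoiceGenderClustering"])

evaluation = mteb.MTEB(tasks=tasks)
results = evaluation.run(model, output_folder="results/wav2vec2-base")
print(results)
```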