diff --git a/mteb/abstasks/Audio/AbsTaskAny2AnyRetrieval.py b/mteb/abstasks/Audio/AbsTaskAny2AnyRetrieval.py index 5fd2e208a0..739dffccf6 100644 --- a/mteb/abstasks/Audio/AbsTaskAny2AnyRetrieval.py +++ b/mteb/abstasks/Audio/AbsTaskAny2AnyRetrieval.py @@ -8,10 +8,9 @@ from time import time from typing import Any +import torch import tqdm from datasets import Features, Value, load_dataset -import torch -import torchaudio from ...evaluation.evaluators import Any2AnyRetrievalEvaluator from ..AbsTask import AbsTask, ScoresDict @@ -43,9 +42,15 @@ def __init__( if prefix: query_file = prefix + "-" + query_file qrels_folder = prefix + "-" + qrels_folder - self.corpus_file = os.path.join(data_folder, corpus_file) if data_folder else corpus_file - self.query_file = os.path.join(data_folder, query_file) if data_folder else query_file - self.qrels_folder = os.path.join(data_folder, qrels_folder) if data_folder else None + self.corpus_file = ( + os.path.join(data_folder, corpus_file) if data_folder else corpus_file + ) + self.query_file = ( + os.path.join(data_folder, query_file) if data_folder else query_file + ) + self.qrels_folder = ( + os.path.join(data_folder, qrels_folder) if data_folder else None + ) self.qrels_file = qrels_file self.streaming = streaming self.keep_in_memory = keep_in_memory @@ -57,7 +62,9 @@ def check(fIn: str, ext: str): if not fIn.endswith(ext): raise ValueError(f"File {fIn} must have extension {ext}") - def load(self, split="test") -> tuple[ + def load( + self, split="test" + ) -> tuple[ dict[str, dict[str, str | torch.Tensor]], dict[str, dict[str, str | torch.Tensor]], dict[str, dict[str, int]], @@ -71,7 +78,9 @@ def load(self, split="test") -> tuple[ if not len(self.corpus): logger.info("Loading Corpus...") self._load_corpus() - logger.info("Loaded %d Documents for %s split.", len(self.corpus), split.upper()) + logger.info( + "Loaded %d Documents for %s split.", len(self.corpus), split.upper() + ) logger.info("Doc Example: %s", self.corpus[0]) if 
not len(self.queries): @@ -80,8 +89,10 @@ def load(self, split="test") -> tuple[ self._load_qrels(split) qrels_dict = defaultdict(dict) + def qrels_dict_init(row): qrels_dict[row["query-id"]][row["corpus-id"]] = int(row["score"]) + self.qrels.map(qrels_dict_init) self.qrels = qrels_dict self.queries = self.queries.filter(lambda x: x["id"] in self.qrels) @@ -150,18 +161,21 @@ def _load_qrels(self, split): ) if "Q0" in qrels_ds.column_names: qrels_ds = qrels_ds.remove_columns("Q0") - features = Features({ - "query-id": Value("string"), - "corpus-id": Value("string"), - "score": Value("float"), - }) - qrels_ds = qrels_ds.select_columns(["query-id", "corpus-id", "score"]).cast(features) + features = Features( + { + "query-id": Value("string"), + "corpus-id": Value("string"), + "score": Value("float"), + } + ) + qrels_ds = qrels_ds.select_columns(["query-id", "corpus-id", "score"]).cast( + features + ) self.qrels = qrels_ds class AbsTaskAny2AnyRetrieval(AbsTask): - """ - Abstract class for audio-text retrieval experiments. + """Abstract class for audio-text retrieval experiments. 
Child-classes must implement: - self.corpus: dict[str, dict[str, str]] @@ -190,7 +204,9 @@ def load_data(self, **kwargs): keep_in_memory=False, ).load(split=split) self.corpus[split], self.queries[split], self.relevant_docs[split] = ( - corpus, queries, qrels + corpus, + queries, + qrels, ) self.data_loaded = True @@ -249,10 +265,16 @@ def _evaluate_subset( if top_k is not None: for qid in list(results.keys()): doc_ids = set( - sorted(results[qid], key=lambda x: results[qid][x], reverse=True)[:top_k] + sorted( + results[qid], key=lambda x: results[qid][x], reverse=True + )[:top_k] ) - results[qid] = {k: v for k, v in results[qid].items() if k in doc_ids} - predictions_path = output_folder / f"{self.metadata.name}_{hf_subset}_predictions.json" + results[qid] = { + k: v for k, v in results[qid].items() if k in doc_ids + } + predictions_path = ( + output_folder / f"{self.metadata.name}_{hf_subset}_predictions.json" + ) with open(predictions_path, "w") as f: json.dump(results, f) @@ -295,14 +317,20 @@ def _evaluate_subset( results[qid] = dict(sorted_docs) for qid, retrieved_docs in results.items(): expected_docs = relevant_docs[qid] - false_positives = [doc for doc in retrieved_docs if doc not in expected_docs] - false_negatives = [doc for doc in expected_docs if doc not in retrieved_docs] + false_positives = [ + doc for doc in retrieved_docs if doc not in expected_docs + ] + false_negatives = [ + doc for doc in expected_docs if doc not in retrieved_docs + ] if false_positives or false_negatives: errors[qid] = { "false_positives": false_positives, "false_negatives": false_negatives, } - errors_path = output_folder / f"{self.metadata.name}_{hf_subset}_errors.json" + errors_path = ( + output_folder / f"{self.metadata.name}_{hf_subset}_errors.json" + ) with open(errors_path, "w") as f: json.dump(errors, f) @@ -319,13 +347,17 @@ def _calculate_metrics_from_split( def calculate_metadata_metrics(self) -> None: self.load_data() all_details = {} - pbar_split = 
tqdm.tqdm(self.metadata_dict["eval_splits"], desc="Processing Splits...") + pbar_split = tqdm.tqdm( + self.metadata_dict["eval_splits"], desc="Processing Splits..." + ) for split in pbar_split: pbar_split.set_postfix_str(f"Split: {split}") logger.info(f"Processing metadata for split {split}") all_details[split] = {} if self.is_multilingual: - pbar_lang = tqdm.tqdm(self.relevant_docs.keys(), desc="Processing Languages...") + pbar_lang = tqdm.tqdm( + self.relevant_docs.keys(), desc="Processing Languages..." + ) for lang in pbar_lang: pbar_lang.set_postfix_str(f"Language: {lang}") logger.info(f"Processing metadata for language {lang}") @@ -358,7 +390,9 @@ def process_language(relevant_docs, queries, corpus, lang=None): logger.info(f"Average query length{language_description} is {query_len}") logger.info(f"Number of documents{language_description} is {num_documents}") logger.info(f"Number of queries{language_description} is {num_queries}") - logger.info(f"Average relevant docs per query{language_description} is {qrels_per_doc}") + logger.info( + f"Average relevant docs per query{language_description} is {qrels_per_doc}" + ) return { "average_document_length": doc_len, "average_query_length": query_len, diff --git a/mteb/abstasks/TaskMetadata.py b/mteb/abstasks/TaskMetadata.py index 46558ca69a..fa6978e171 100644 --- a/mteb/abstasks/TaskMetadata.py +++ b/mteb/abstasks/TaskMetadata.py @@ -59,7 +59,7 @@ "Activity recognition", "Tumor detection", "Duplicate Detection", - "Voice Gender Clustering", + "Gender Clustering", "Voice Emotion Clustering", ] @@ -141,6 +141,7 @@ "i2it", # image-to-image+text "t2it", # text-to-image+text "it2it", # image+text-to-image+text + "a2a", # audio-to-audio ] ANNOTATOR_TYPE = Literal[ diff --git a/mteb/evaluation/evaluators/Audio/Any2AnyRetrievalEvaluator.py b/mteb/evaluation/evaluators/Audio/Any2AnyRetrievalEvaluator.py index bd03af28b4..cf9e55f27c 100644 --- a/mteb/evaluation/evaluators/Audio/Any2AnyRetrievalEvaluator.py +++ 
b/mteb/evaluation/evaluators/Audio/Any2AnyRetrievalEvaluator.py @@ -12,8 +12,8 @@ import numpy as np import pytrec_eval import torch -from datasets import Dataset import torchaudio +from datasets import Dataset from torch.utils.data import DataLoader from mteb.encoder_interface import Encoder, PromptType @@ -33,8 +33,10 @@ logger = logging.getLogger(__name__) + # A default transform for audio; replace with a more meaningful transform as needed. -DEFAULT_AUDIO_TRANSFORM = lambda x: x +def DEFAULT_AUDIO_TRANSFORM(x): + return x class AudioDataset(torch.utils.data.Dataset): @@ -342,7 +344,9 @@ def evaluate( if qid == pid: results[qid].pop(pid) else: - logger.debug("Not ignoring identical query and document ids. Set ignore_identical_ids=True to ignore.") + logger.debug( + "Not ignoring identical query and document ids. Set ignore_identical_ids=True to ignore." + ) all_ndcgs, all_aps, all_recalls, all_precisions, all_cv_recalls = ( {}, @@ -404,7 +408,9 @@ def evaluate( _map[f"MAP@{k}"] = round(sum(_map[f"MAP@{k}"]) / len(scores), 5) recall[f"Recall@{k}"] = round(sum(recall[f"Recall@{k}"]) / len(scores), 5) precision[f"P@{k}"] = round(sum(precision[f"P@{k}"]) / len(scores), 5) - cv_recall[f"CV_Recall@{k}"] = round(sum(cv_recall[f"CV_Recall@{k}"]) / len(scores), 5) + cv_recall[f"CV_Recall@{k}"] = round( + sum(cv_recall[f"CV_Recall@{k}"]) / len(scores), 5 + ) naucs = Any2AnyRetrievalEvaluator.evaluate_abstention( results, @@ -427,7 +433,13 @@ def evaluate_custom( metric_scores = recall_cap(qrels, results, k_values, output_type) elif metric.lower() in ["hole", "hole@k"]: metric_scores = hole(qrels, results, k_values, output_type) - elif metric.lower() in ["acc", "top_k_acc", "accuracy", "accuracy@k", "top_k_accuracy"]: + elif metric.lower() in [ + "acc", + "top_k_acc", + "accuracy", + "accuracy@k", + "top_k_accuracy", + ]: metric_scores = top_k_accuracy(qrels, results, k_values, output_type) naucs = Any2AnyRetrievalEvaluator.evaluate_abstention(results, metric_scores) 
metric_scores_avg = {k: sum(v) / len(v) for k, v in metric_scores.items()} @@ -460,5 +472,9 @@ def calculate_cv_style_recall( cv_recalls = {} for query_id, relevant_docs in qrels.items(): retrieved_docs = list(results.get(query_id, {}).keys())[:k] - cv_recalls[query_id] = 1.0 if any(doc_id in relevant_docs for doc_id in retrieved_docs) else 0.0 + cv_recalls[query_id] = ( + 1.0 + if any(doc_id in relevant_docs for doc_id in retrieved_docs) + else 0.0 + ) return cv_recalls diff --git a/mteb/models/wav2vec_models.py b/mteb/models/wav2vec_models.py index aa14f6d095..8a9e3246da 100644 --- a/mteb/models/wav2vec_models.py +++ b/mteb/models/wav2vec_models.py @@ -1,62 +1,53 @@ +from __future__ import annotations + from functools import partial from mteb.models.wrapper import Wrapper -from mteb.encoder_interface import PromptType, AudioEncoder import numpy as np import torch -from transformers import Wav2Vec2Model, Wav2Vec2FeatureExtractor -from mteb.model_meta import ModelMeta from datasets import Audio +from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2Model +from mteb.encoder_interface import AudioEncoder, PromptType +from mteb.model_meta import ModelMeta class Wav2vec2Wrapper(AudioEncoder): def __init__( - self, - model_name: str, - # revision: str, - device: str | None = None, - **kwargs + self, + device: str | None = None, + model_name="facebook/wav2vec2-base", + model_revision=None, + **kwargs, ): - super().__init__(device=device, **kwargs) - self.model_name = model_name - # self.model_revision = revision - - self.model = Wav2Vec2Model.from_pretrained( - self.model_name, - # revision=self.model_revision - ) + self.device = device + self.model = Wav2Vec2Model.from_pretrained(model_name, revision=model_revision) self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained( - self.model_name, - # revision=self.model_revision + model_name, revision=model_revision ) self.embed_dim = self.model.config.hidden_size if device: self.model = self.model.to(device) - 
print("Wav2vec initialized.") def get_audio_embeddings( - self, - audio_files: list[Audio] | Audio, - batch_size: int = 32, - **kwargs + self, audio_files: list[Audio] | Audio, batch_size: int = 32, **kwargs ) -> np.ndarray: - if not isinstance(audio_files, list): audio_files = [audio_files] all_embeddings = [] for i in range(0, len(audio_files), batch_size): - batch = audio_files[i:i + batch_size] + batch = audio_files[i : i + batch_size] - audio_data = [file['array'] for file in batch] - sampling_rates = [file['sampling_rate'] for file in batch] + audio_data = [file["array"] for file in batch] + sampling_rates = [file["sampling_rate"] for file in batch] # Preprocess batch inputs = self.feature_extractor( audio_data, sampling_rate=sampling_rates[0], padding=True, + return_tensors="pt" ) @@ -68,32 +60,37 @@ def get_audio_embeddings( outputs = self.model( input_values=inputs["input_values"], output_hidden_states=True, - return_dict=True + return_dict=True, ) - hidden_states = outputs.hidden_states[6] - print(hidden_states.shape) + hidden_states = outputs.hidden_states[-1] + batch_embeddings = hidden_states.mean(dim=1).cpu().numpy() all_embeddings.append(batch_embeddings) return np.vstack(all_embeddings) def encode( - self, - audio_files: list[Audio], - *, - task_name: str, - prompt_type: PromptType | None = None, - **kwargs + self, + audio_files: list[Audio], + *, + task_name: str, + prompt_type: PromptType | None = None, + batch_size: int = 32, + **kwargs, ) -> np.ndarray: - - return self.get_audio_embeddings(audio_files, **kwargs) + return self.get_audio_embeddings(audio_files, batch_size=batch_size, **kwargs) wav2vec2_base = ModelMeta( - loader=partial(Wav2vec2Wrapper, model_name="facebook/wav2vec2-base"), + loader=partial( + Wav2vec2Wrapper, + model_name="facebook/wav2vec2-base", + model_revision="0b5b8e868dd84f03fd87d01f9c4ff0f080fecfe8", + ), name="facebook/wav2vec2-base", - languages=["eng"], + languages=["en"], + open_weights=True, 
revision="0b5b8e868dd84f03fd87d01f9c4ff0f080fecfe8", release_date="2020-10-26", @@ -109,13 +106,19 @@ def encode( public_training_code=None, public_training_data=None, training_datasets=None, - modalities=["audio"] + modalities=["audio"], ) + wav2vec2_base_960h = ModelMeta( - loader=partial(Wav2vec2Wrapper, model_name="facebook/wav2vec2-base-960h"), + loader=partial( + Wav2vec2Wrapper, + model_name="facebook/wav2vec2-base-960h", + model_revision="22aad52d435eb6dbaf354bdad9b0da84ce7d6156", + ), name="facebook/wav2vec2-base-960h", - languages=["eng"], + languages=["en"], + open_weights=True, revision="22aad52d435eb6dbaf354bdad9b0da84ce7d6156", release_date="2020-10-26", @@ -131,13 +134,20 @@ def encode( public_training_code=None, public_training_data=None, training_datasets=None, - modalities=["audio"] + + modalities=["audio"], ) + wav2vec2_large = ModelMeta( - loader=partial(Wav2vec2Wrapper, model_name="facebook/wav2vec2-large"), + loader=partial( + Wav2vec2Wrapper, + model_name="facebook/wav2vec2-large", + model_revision="312b2410566b698c7a649068d413b2067848bd75", + ), name="facebook/wav2vec2-large", - languages=["eng"], + languages=["en"], + open_weights=True, revision="312b2410566b698c7a649068d413b2067848bd75", release_date="2020-10-26", @@ -153,13 +163,20 @@ def encode( public_training_code=None, public_training_data=None, training_datasets=None, - modalities=["audio"] + + modalities=["audio"], ) + wav2vec2_large_xlsr_53 = ModelMeta( - loader=partial(Wav2vec2Wrapper, model_name="facebook/wav2vec2-large-xlsr-53"), + loader=partial( + Wav2vec2Wrapper, + model_name="facebook/wav2vec2-large-xlsr-53", + model_revision="c3f9d884181a224a6ac87bf8885c84d1cff3384f", + ), name="facebook/wav2vec2-large-xlsr-53", - languages=["multilingual"], + languages=["en"], + open_weights=True, revision="c3f9d884181a224a6ac87bf8885c84d1cff3384f", release_date="2020-10-26", @@ -175,13 +192,19 @@ def encode( public_training_code=None, public_training_data=None, training_datasets=None, - 
modalities=["audio"] + modalities=["audio"], ) + wav2vec2_lv_60_espeak_cv_ft = ModelMeta( - loader=partial(Wav2vec2Wrapper, model_name="facebook/wav2vec2-lv-60-espeak-cv-ft"), + loader=partial( + Wav2vec2Wrapper, + model_name="facebook/wav2vec2-lv-60-espeak-cv-ft", + model_revision="ae45363bf3413b374fecd9dc8bc1df0e24c3b7f4", + ), name="facebook/wav2vec2-lv-60-espeak-cv-ft", - languages=["multilingual"], + languages=["en"], + open_weights=True, revision="ae45363bf3413b374fecd9dc8bc1df0e24c3b7f4", release_date="2020-10-26", @@ -197,7 +220,6 @@ def encode( public_training_code=None, public_training_data=None, training_datasets=None, - modalities=["audio"] + modalities=["audio"], ) -# print(f"wav2vec2_lv_60_espeak_cv_ft: {wav2vec2_lv_60_espeak_cv_ft.calculate_memory_usage_mb()}") diff --git a/mteb/tasks/Audio/Clustering/__init__.py b/mteb/tasks/Audio/Clustering/__init__.py index 6f3434ecff..58bc6a22d8 100644 --- a/mteb/tasks/Audio/Clustering/__init__.py +++ b/mteb/tasks/Audio/Clustering/__init__.py @@ -1,3 +1,3 @@ from __future__ import annotations -from .eng.VoiceGender import * \ No newline at end of file +from .eng.VoiceGender import * diff --git a/mteb/tasks/Audio/Clustering/eng/VoiceGender.py b/mteb/tasks/Audio/Clustering/eng/VoiceGender.py index 763d8c8db8..ce65b0ba3a 100644 --- a/mteb/tasks/Audio/Clustering/eng/VoiceGender.py +++ b/mteb/tasks/Audio/Clustering/eng/VoiceGender.py @@ -1,10 +1,11 @@ +from __future__ import annotations + from mteb.abstasks.Audio.AbsTaskAudioClustering import AbsTaskAudioClustering from mteb.abstasks.TaskMetadata import TaskMetadata import mteb from mteb import MTEB - class VoiceGenderClustering(AbsTaskAudioClustering): label_column_name: str = "label" metadata = TaskMetadata( @@ -13,22 +14,30 @@ class VoiceGenderClustering(AbsTaskAudioClustering): reference="https://huggingface.co/datasets/mmn3690/voice-gender-clustering", dataset={ "path": "mmn3690/voice-gender-clustering", - "revision": "main", + "revision": 
"1b202ea7bcd0abd5283e628248803e1569257c80", + }, type="AudioClustering", category="a2a", eval_splits=["train"], eval_langs=["eng-Latn"], - main_score="nmi", + main_score="clustering_accuracy", date=("2024-01-01", "2024-12-31"), domains=["Spoken"], - task_subtypes=["Voice Gender Clustering"], + task_subtypes=["Gender Clustering"], license="not specified", annotations_creators="derived", dialect=[], modalities=["audio"], + sample_creation="found", + bibtex_citation="""@InProceedings{Chung18b, + author = "Chung, J.~S. and Nagrani, A. and Zisserman, A.", + title = "VoxCeleb2: Deep Speaker Recognition", + booktitle = "INTERSPEECH", + year = "2018 + }""", ) - + if __name__ == "__main__": #model_name = "microsoft/wavlm-base" model_name = "facebook/wav2vec2-base" diff --git a/mteb/tasks/Audio/__init__.py b/mteb/tasks/Audio/__init__.py index 4ae414ff9e..9777bd4544 100644 --- a/mteb/tasks/Audio/__init__.py +++ b/mteb/tasks/Audio/__init__.py @@ -1,3 +1,3 @@ from __future__ import annotations -from .Clustering import * \ No newline at end of file +from .Clustering import * diff --git a/mteb/tasks/__init__.py b/mteb/tasks/__init__.py index e00f091174..25a8f78503 100644 --- a/mteb/tasks/__init__.py +++ b/mteb/tasks/__init__.py @@ -1,6 +1,7 @@ from __future__ import annotations from .aggregated_tasks import * +from .Audio.Clustering import * from .BitextMining import * from .Classification import * from .Clustering import * diff --git a/pyproject.toml b/pyproject.toml index 1086a34e73..21bd9ee4d0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,6 +41,7 @@ dependencies = [ "eval_type_backport>=0.0.0", "polars>=0.20.22", "torchvision>0.0.0", + "torchaudio>=2.6.0" ] diff --git a/tests/test_benchmark/mock_models.py b/tests/test_benchmark/mock_models.py index dece34e703..eb1764f5e0 100644 --- a/tests/test_benchmark/mock_models.py +++ b/tests/test_benchmark/mock_models.py @@ -72,28 +72,24 @@ def __init__(self): def get_audio_embeddings( self, - audio, # list + audio, # list 
**kwargs, ) -> np.ndarray: - return np.random.rand(len(audio), self.embedding_dim) def get_text_embeddings( self, - texts, # list + texts, # list **kwargs, - ) -> np.ndarray: + ) -> np.ndarray: pass def calculate_probs( - self, - text_embeddings: np.ndarray, - audio_embeddings: np.ndarray + self, text_embeddings: np.ndarray, audio_embeddings: np.ndarray ) -> np.ndarray: pass - class MockSentenceTransformer(SentenceTransformer): """A mock implementation of the SentenceTransformer intended to implement just the encode, method using the same arguments.""" diff --git a/tests/test_benchmark/mock_tasks.py b/tests/test_benchmark/mock_tasks.py index 8863c8c4d0..f2cc0aa563 100644 --- a/tests/test_benchmark/mock_tasks.py +++ b/tests/test_benchmark/mock_tasks.py @@ -10,7 +10,6 @@ from mteb.abstasks.AbsTaskBitextMining import AbsTaskBitextMining from mteb.abstasks.AbsTaskClassification import AbsTaskClassification from mteb.abstasks.AbsTaskClustering import AbsTaskClustering -from mteb.abstasks.Audio.AbsTaskAudioClustering import AbsTaskAudioClustering from mteb.abstasks.AbsTaskClusteringFast import AbsTaskClusteringFast from mteb.abstasks.AbsTaskInstructionRetrieval import AbsTaskInstructionRetrieval from mteb.abstasks.AbsTaskMultilabelClassification import ( @@ -21,6 +20,7 @@ from mteb.abstasks.AbsTaskRetrieval import AbsTaskRetrieval from mteb.abstasks.AbsTaskSTS import AbsTaskSTS from mteb.abstasks.AbsTaskSummarization import AbsTaskSummarization +from mteb.abstasks.Audio.AbsTaskAudioClustering import AbsTaskAudioClustering from mteb.abstasks.Image.AbsTaskAny2AnyMultiChoice import AbsTaskAny2AnyMultiChoice from mteb.abstasks.Image.AbsTaskAny2AnyRetrieval import AbsTaskAny2AnyRetrieval from mteb.abstasks.Image.AbsTaskAny2TextMultipleChoice import ( @@ -478,17 +478,18 @@ def load_data(self, **kwargs): ) self.data_loaded = True + class MockAudioClusteringTask(AbsTaskAudioClustering): expected_stats = { "test": { - "num_samples": 3, + "num_samples": 3, "number_of_samples": 3, 
"min_audio_length": 16000, # sr = 16000 - "average_audio_length": 16000, # 1s - "max_audio_length": 16000, # 1s + "average_audio_length": 16000, # 1s + "max_audio_length": 16000, # 1s "unique_audios": 3, "min_labels_per_audio": 1, - "average_labels_per_audio": 1.0, + "average_labels_per_audio": 1.0, "max_labels_per_audio": 1, "unique_labels": 3, "labels": {"0": {"count": 1}, "1": {"count": 1}, "2": {"count": 1}}, @@ -499,26 +500,26 @@ class MockAudioClusteringTask(AbsTaskAudioClustering): type="Clustering", name="MockAudioClusteringTask", main_score="v_measure", - **general_args, + **general_args, ) def load_data(self, **kwargs): mock_audio = [ { "array": np.random.rand(16000), # 1s - "sampling_rate": 16000 - } for _ in range(3) + "sampling_rate": 16000, + } + for _ in range(3) ] - - - labels = [0, 1, 2] + + labels = [0, 1, 2] self.dataset = DatasetDict( { "test": Dataset.from_dict( { - "audio": mock_audio, - "labels": labels, + "audio": mock_audio, + "labels": labels, } ), } diff --git a/tests/test_benchmark/task_grid.py b/tests/test_benchmark/task_grid.py index 47c4953eb6..797352977e 100644 --- a/tests/test_benchmark/task_grid.py +++ b/tests/test_benchmark/task_grid.py @@ -14,11 +14,11 @@ from .mock_tasks import ( MockAny2AnyRetrievalI2TTask, MockAny2AnyRetrievalT2ITask, + MockAudioClusteringTask, MockBitextMiningTask, MockClassificationTask, MockClusteringFastTask, MockClusteringTask, - MockAudioClusteringTask, MockImageClassificationKNNPTTask, MockImageClassificationKNNTask, MockImageClassificationTask, @@ -136,10 +136,6 @@ MockMultilingualImageMultilabelClassificationTask(), ] -MOCK_MAEB_TASK_GRID = [ - MockAudioClusteringTask() -] - MOCK_MIEB_TASK_GRID_AS_STRING = [ t.metadata.name if isinstance(t, AbsTask) else t for t in MOCK_MIEB_TASK_GRID ] @@ -147,3 +143,13 @@ MOCK_MIEB_TASK_REGISTRY = { task.metadata.name: type(task) for task in MOCK_MIEB_TASK_GRID } + +MOCK_MAEB_TASK_GRID = [MockAudioClusteringTask()] + +MOCK_MAEB_TASK_GRID_AS_STRING = [ + 
t.metadata.name if isinstance(t, AbsTask) else t for t in MOCK_MAEB_TASK_GRID +] + +MOCK_MAEB_TASK_REGISTRY = { + task.metadata.name: type(task) for task in MOCK_MAEB_TASK_GRID +} diff --git a/tests/test_tasks/test_all_abstasks.py b/tests/test_tasks/test_all_abstasks.py index 7a87914f0a..16ff888813 100644 --- a/tests/test_tasks/test_all_abstasks.py +++ b/tests/test_tasks/test_all_abstasks.py @@ -20,13 +20,18 @@ from mteb.overview import TASKS_REGISTRY from ..test_benchmark.task_grid import ( + MOCK_MAEB_TASK_GRID_AS_STRING, MOCK_MIEB_TASK_GRID_AS_STRING, MOCK_TASK_TEST_GRID_AS_STRING, ) logging.basicConfig(level=logging.INFO) -ALL_MOCK_TASKS = MOCK_TASK_TEST_GRID_AS_STRING + MOCK_MIEB_TASK_GRID_AS_STRING +ALL_MOCK_TASKS = ( + MOCK_TASK_TEST_GRID_AS_STRING + + MOCK_MIEB_TASK_GRID_AS_STRING + + MOCK_MAEB_TASK_GRID_AS_STRING +) tasks = [t for t in MTEB().tasks_cls if t.metadata.name not in ALL_MOCK_TASKS] @@ -101,6 +106,7 @@ def test_dataset_availability(): for t in tasks if t.metadata.name not in MOCK_TASK_TEST_GRID_AS_STRING if t.metadata.name not in MOCK_MIEB_TASK_GRID_AS_STRING + if t.metadata.name not in MOCK_MAEB_TASK_GRID_AS_STRING and t.metadata.name != "AfriSentiLangClassification" # HOTFIX: Issue#1777. Remove this line when issue is resolved. 
] diff --git a/tests/test_tasks/test_maeb_datasets.py b/tests/test_tasks/test_maeb_datasets.py index fdd4b2f76d..c1a1b21a8b 100644 --- a/tests/test_tasks/test_maeb_datasets.py +++ b/tests/test_tasks/test_maeb_datasets.py @@ -21,4 +21,4 @@ def test_benchmark_audio_encoder(task: str | AbsTask, model: mteb.Encoder): """Test that a task can be fetched and run""" eval = MTEB(tasks=[task]) - eval.run(model, output_folder="tests/results", overwrite_results=True) \ No newline at end of file + eval.run(model, output_folder="tests/results", overwrite_results=True) diff --git a/tests/test_tasks/test_mieb_datasets.py b/tests/test_tasks/test_mieb_datasets.py index 431440d9bd..26e60931ec 100644 --- a/tests/test_tasks/test_mieb_datasets.py +++ b/tests/test_tasks/test_mieb_datasets.py @@ -21,4 +21,4 @@ def test_benchmark_sentence_transformer(task: str | AbsTask, model: mteb.Encoder): """Test that a task can be fetched and run""" eval = MTEB(tasks=[task]) - eval.run(model, output_folder="tests/results", overwrite_results=True) \ No newline at end of file + eval.run(model, output_folder="tests/results", overwrite_results=True)