Merged

Changes from all commits (25 commits)
ea4651a
Added wav2vec model wrapper
alisartazkhan Feb 22, 2025
557460a
Added four w2v variants
alisartazkhan Feb 23, 2025
401debb
Update wav2vec_models.py
alisartazkhan Feb 23, 2025
3e80108
Removed run.py test script
sufen-f Feb 27, 2025
6614c51
Added subTask with small sample of dataset for testing
Feb 22, 2025
b471057
Removed test portion of VoiceGender.py task
sufen-f Feb 27, 2025
cd55c46
add commit hash and bibtex
isaac-chung Feb 27, 2025
2ecba04
make lint
isaac-chung Feb 27, 2025
6109f87
update models
isaac-chung Feb 27, 2025
880bcbe
fix circular import
isaac-chung Feb 27, 2025
af38fe4
make VoiceGender discoverable in get_tasks
isaac-chung Feb 27, 2025
3ce93be
add a2a as category for clustering
isaac-chung Feb 27, 2025
a378ec0
specify latest commit hash
isaac-chung Feb 27, 2025
5fe8087
revert linting changes
isaac-chung Feb 27, 2025
af3de65
Based on feedback for model: updated w2v2 revisions and added torchau…
alisartazkhan Feb 28, 2025
0d861db
Added Bibtex for dataset, set data to be test instead of training, sh…
sufen-f Feb 28, 2025
144ec83
Changed task from Voice Gender Clustering to Gender Clustering.
sufen-f Feb 28, 2025
28b6c6b
Fixed mock audio clustering tests
sufen-f Feb 28, 2025
ecde41d
Added dataset metadata
sufen-f Feb 28, 2025
743f832
Linted
sufen-f Feb 28, 2025
298caa5
Passed revision into the w2v2 loader
alisartazkhan Feb 28, 2025
398d024
Merge branch 'models' of https://github.com/sufen-f/mteb_audio into m…
alisartazkhan Feb 28, 2025
22fd001
passed lint check
alisartazkhan Feb 28, 2025
6b2cb71
Linted
sufen-f Feb 28, 2025
3f23a87
Update VoiceGender.py
alisartazkhan Feb 28, 2025
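
Taken together, these commits add revision-pinned wav2vec 2.0 model wrappers and a Gender Clustering audio task. Below is a minimal, hypothetical sketch of what such a wrapper could look like; the checkpoint name, the revision handling, and the mean-pooling step are assumptions for illustration, not the PR's actual wav2vec_models code.

import torch
from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2Model


class Wav2Vec2Wrapper:
    """Toy wav2vec 2.0 embedder; pools frame features into one vector per clip."""

    def __init__(self, model_name="facebook/wav2vec2-base", revision=None):
        # Pinning `revision` to a commit hash keeps results reproducible, in the
        # spirit of the "specify latest commit hash" / "Passed revision into the
        # w2v2 loader" commits above.
        self.extractor = Wav2Vec2FeatureExtractor.from_pretrained(model_name, revision=revision)
        self.model = Wav2Vec2Model.from_pretrained(model_name, revision=revision)
        self.model.eval()

    @torch.no_grad()
    def encode(self, waveforms, sampling_rate=16_000):
        # `waveforms` is a list of 1-D float arrays/tensors at `sampling_rate` Hz.
        inputs = self.extractor(
            [w.numpy() if isinstance(w, torch.Tensor) else w for w in waveforms],
            sampling_rate=sampling_rate,
            return_tensors="pt",
            padding=True,
        )
        hidden = self.model(**inputs).last_hidden_state  # (batch, frames, dim)
        return hidden.mean(dim=1)  # (batch, dim)

In the PR itself, the wrapper variants are exposed through mteb/models/overview.py, where wav2vec_models is imported and appended to the model registry (see the last diff below).
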
84 changes: 59 additions & 25 deletions mteb/abstasks/Audio/AbsTaskAny2AnyRetrieval.py
@@ -8,10 +8,9 @@
from time import time
from typing import Any

import torch
import tqdm
from datasets import Features, Value, load_dataset
import torch
import torchaudio

from ...evaluation.evaluators import Any2AnyRetrievalEvaluator
from ..AbsTask import AbsTask, ScoresDict
@@ -43,9 +42,15 @@ def __init__(
if prefix:
query_file = prefix + "-" + query_file
qrels_folder = prefix + "-" + qrels_folder
self.corpus_file = os.path.join(data_folder, corpus_file) if data_folder else corpus_file
self.query_file = os.path.join(data_folder, query_file) if data_folder else query_file
self.qrels_folder = os.path.join(data_folder, qrels_folder) if data_folder else None
self.corpus_file = (
os.path.join(data_folder, corpus_file) if data_folder else corpus_file
)
self.query_file = (
os.path.join(data_folder, query_file) if data_folder else query_file
)
self.qrels_folder = (
os.path.join(data_folder, qrels_folder) if data_folder else None
)
self.qrels_file = qrels_file
self.streaming = streaming
self.keep_in_memory = keep_in_memory
@@ -57,7 +62,9 @@ def check(fIn: str, ext: str):
if not fIn.endswith(ext):
raise ValueError(f"File {fIn} must have extension {ext}")

def load(self, split="test") -> tuple[
def load(
self, split="test"
) -> tuple[
dict[str, dict[str, str | torch.Tensor]],
dict[str, dict[str, str | torch.Tensor]],
dict[str, dict[str, int]],
@@ -71,7 +78,9 @@ def load(self, split="test") -> tuple[
if not len(self.corpus):
logger.info("Loading Corpus...")
self._load_corpus()
logger.info("Loaded %d Documents for %s split.", len(self.corpus), split.upper())
logger.info(
"Loaded %d Documents for %s split.", len(self.corpus), split.upper()
)
logger.info("Doc Example: %s", self.corpus[0])

if not len(self.queries):
@@ -80,8 +89,10 @@

self._load_qrels(split)
qrels_dict = defaultdict(dict)

def qrels_dict_init(row):
qrels_dict[row["query-id"]][row["corpus-id"]] = int(row["score"])

self.qrels.map(qrels_dict_init)
self.qrels = qrels_dict
self.queries = self.queries.filter(lambda x: x["id"] in self.qrels)
@@ -150,18 +161,21 @@ def _load_qrels(self, split):
)
if "Q0" in qrels_ds.column_names:
qrels_ds = qrels_ds.remove_columns("Q0")
features = Features({
"query-id": Value("string"),
"corpus-id": Value("string"),
"score": Value("float"),
})
qrels_ds = qrels_ds.select_columns(["query-id", "corpus-id", "score"]).cast(features)
features = Features(
{
"query-id": Value("string"),
"corpus-id": Value("string"),
"score": Value("float"),
}
)
qrels_ds = qrels_ds.select_columns(["query-id", "corpus-id", "score"]).cast(
features
)
self.qrels = qrels_ds


class AbsTaskAny2AnyRetrieval(AbsTask):
"""
Abstract class for audio-text retrieval experiments.
"""Abstract class for audio-text retrieval experiments.

Child-classes must implement:
- self.corpus: dict[str, dict[str, str]]
@@ -190,7 +204,9 @@ def load_data(self, **kwargs):
keep_in_memory=False,
).load(split=split)
self.corpus[split], self.queries[split], self.relevant_docs[split] = (
corpus, queries, qrels
corpus,
queries,
qrels,
)
self.data_loaded = True

@@ -249,10 +265,16 @@ def _evaluate_subset(
if top_k is not None:
for qid in list(results.keys()):
doc_ids = set(
sorted(results[qid], key=lambda x: results[qid][x], reverse=True)[:top_k]
sorted(
results[qid], key=lambda x: results[qid][x], reverse=True
)[:top_k]
)
results[qid] = {k: v for k, v in results[qid].items() if k in doc_ids}
predictions_path = output_folder / f"{self.metadata.name}_{hf_subset}_predictions.json"
results[qid] = {
k: v for k, v in results[qid].items() if k in doc_ids
}
predictions_path = (
output_folder / f"{self.metadata.name}_{hf_subset}_predictions.json"
)
with open(predictions_path, "w") as f:
json.dump(results, f)

@@ -295,14 +317,20 @@ def _evaluate_subset(
results[qid] = dict(sorted_docs)
for qid, retrieved_docs in results.items():
expected_docs = relevant_docs[qid]
false_positives = [doc for doc in retrieved_docs if doc not in expected_docs]
false_negatives = [doc for doc in expected_docs if doc not in retrieved_docs]
false_positives = [
doc for doc in retrieved_docs if doc not in expected_docs
]
false_negatives = [
doc for doc in expected_docs if doc not in retrieved_docs
]
if false_positives or false_negatives:
errors[qid] = {
"false_positives": false_positives,
"false_negatives": false_negatives,
}
errors_path = output_folder / f"{self.metadata.name}_{hf_subset}_errors.json"
errors_path = (
output_folder / f"{self.metadata.name}_{hf_subset}_errors.json"
)
with open(errors_path, "w") as f:
json.dump(errors, f)

@@ -319,13 +347,17 @@ def _calculate_metrics_from_split(
def calculate_metadata_metrics(self) -> None:
self.load_data()
all_details = {}
pbar_split = tqdm.tqdm(self.metadata_dict["eval_splits"], desc="Processing Splits...")
pbar_split = tqdm.tqdm(
self.metadata_dict["eval_splits"], desc="Processing Splits..."
)
for split in pbar_split:
pbar_split.set_postfix_str(f"Split: {split}")
logger.info(f"Processing metadata for split {split}")
all_details[split] = {}
if self.is_multilingual:
pbar_lang = tqdm.tqdm(self.relevant_docs.keys(), desc="Processing Languages...")
pbar_lang = tqdm.tqdm(
self.relevant_docs.keys(), desc="Processing Languages..."
)
for lang in pbar_lang:
pbar_lang.set_postfix_str(f"Language: {lang}")
logger.info(f"Processing metadata for language {lang}")
@@ -358,7 +390,9 @@ def process_language(relevant_docs, queries, corpus, lang=None):
logger.info(f"Average query length{language_description} is {query_len}")
logger.info(f"Number of documents{language_description} is {num_documents}")
logger.info(f"Number of queries{language_description} is {num_queries}")
logger.info(f"Average relevant docs per query{language_description} is {qrels_per_doc}")
logger.info(
f"Average relevant docs per query{language_description} is {qrels_per_doc}"
)
return {
"average_document_length": doc_len,
"average_query_length": query_len,
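
For orientation, a hypothetical child task of the abstract class above might populate the attributes named in its docstring as follows. This is illustrative only: the split name, ids, and the exact shapes of queries and relevant_docs (which the truncated docstring does not show) are assumptions, though the per-split keying matches the load_data assignment in the diff.

from mteb.abstasks.Audio.AbsTaskAny2AnyRetrieval import AbsTaskAny2AnyRetrieval


class ToyAudioTextRetrieval(AbsTaskAny2AnyRetrieval):
    """Illustrative only: fills the containers the evaluator consumes."""
    # (a real task would also define a `metadata = TaskMetadata(...)` attribute)

    def load_data(self, **kwargs):
        if self.data_loaded:
            return
        split = "test"
        # corpus: doc id -> fields describing the document (per the docstring)
        self.corpus = {split: {"d1": {"title": "", "text": "a dog barking"}}}
        # queries: query id -> query text or audio (assumed layout)
        self.queries = {split: {"q1": "barking sound"}}
        # relevant_docs: query id -> {doc id: relevance score} (assumed layout)
        self.relevant_docs = {split: {"q1": {"d1": 1}}}
        self.data_loaded = True
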
3 changes: 3 additions & 0 deletions mteb/abstasks/TaskMetadata.py
@@ -59,6 +59,7 @@
"Activity recognition",
"Tumor detection",
"Duplicate Detection",
"Gender Clustering",
]

TASK_DOMAIN = Literal[
@@ -101,6 +102,7 @@
"multiple",
]
TASK_TYPE = Literal[
"AudioClustering",
"BitextMining",
"Classification",
"MultilabelClassification",
@@ -137,6 +139,7 @@
"i2it", # image-to-image+text
"t2it", # text-to-image+text
"it2it", # image+text-to-image+text
"a2a", # audio-to-audio
]

ANNOTATOR_TYPE = Literal[
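
The three literals added here are intended to be used together by the new task. A small fragment of how they could combine in a task's metadata is below; the field names follow TaskMetadata, but this is an illustrative dict, not a complete, validating TaskMetadata instance.

# Illustrative only: how the new literal values slot together for the
# Gender Clustering task. Not a full TaskMetadata definition.
gender_clustering_metadata_fields = {
    "type": "AudioClustering",                # new TASK_TYPE value
    "category": "a2a",                        # new audio-to-audio category
    "task_subtypes": ["Gender Clustering"],   # new TASK_SUBTYPE value
}
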
28 changes: 22 additions & 6 deletions mteb/evaluation/evaluators/Audio/Any2AnyRetrievalEvaluator.py
@@ -12,8 +12,8 @@
import numpy as np
import pytrec_eval
import torch
from datasets import Dataset
import torchaudio
from datasets import Dataset
from torch.utils.data import DataLoader

from mteb.encoder_interface import Encoder, PromptType
@@ -33,8 +33,10 @@

logger = logging.getLogger(__name__)


# A default transform for audio; replace with a more meaningful transform as needed.
DEFAULT_AUDIO_TRANSFORM = lambda x: x
def DEFAULT_AUDIO_TRANSFORM(x):
return x


class AudioDataset(torch.utils.data.Dataset):
@@ -342,7 +344,9 @@ def evaluate(
if qid == pid:
results[qid].pop(pid)
else:
logger.debug("Not ignoring identical query and document ids. Set ignore_identical_ids=True to ignore.")
logger.debug(
"Not ignoring identical query and document ids. Set ignore_identical_ids=True to ignore."
)

all_ndcgs, all_aps, all_recalls, all_precisions, all_cv_recalls = (
{},
@@ -404,7 +408,9 @@ def evaluate(
_map[f"MAP@{k}"] = round(sum(_map[f"MAP@{k}"]) / len(scores), 5)
recall[f"Recall@{k}"] = round(sum(recall[f"Recall@{k}"]) / len(scores), 5)
precision[f"P@{k}"] = round(sum(precision[f"P@{k}"]) / len(scores), 5)
cv_recall[f"CV_Recall@{k}"] = round(sum(cv_recall[f"CV_Recall@{k}"]) / len(scores), 5)
cv_recall[f"CV_Recall@{k}"] = round(
sum(cv_recall[f"CV_Recall@{k}"]) / len(scores), 5
)

naucs = Any2AnyRetrievalEvaluator.evaluate_abstention(
results,
@@ -427,7 +433,13 @@ def evaluate_custom(
metric_scores = recall_cap(qrels, results, k_values, output_type)
elif metric.lower() in ["hole", "hole@k"]:
metric_scores = hole(qrels, results, k_values, output_type)
elif metric.lower() in ["acc", "top_k_acc", "accuracy", "accuracy@k", "top_k_accuracy"]:
elif metric.lower() in [
"acc",
"top_k_acc",
"accuracy",
"accuracy@k",
"top_k_accuracy",
]:
metric_scores = top_k_accuracy(qrels, results, k_values, output_type)
naucs = Any2AnyRetrievalEvaluator.evaluate_abstention(results, metric_scores)
metric_scores_avg = {k: sum(v) / len(v) for k, v in metric_scores.items()}
@@ -460,5 +472,9 @@ def calculate_cv_style_recall(
cv_recalls = {}
for query_id, relevant_docs in qrels.items():
retrieved_docs = list(results.get(query_id, {}).keys())[:k]
cv_recalls[query_id] = 1.0 if any(doc_id in relevant_docs for doc_id in retrieved_docs) else 0.0
cv_recalls[query_id] = (
1.0
if any(doc_id in relevant_docs for doc_id in retrieved_docs)
else 0.0
)
return cv_recalls
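
A toy, self-contained illustration of the CV-style recall computed above: a query scores 1.0 if any of its relevant documents appears among the first k retrieved results (assumed already sorted by score, as evaluate() does before this point), otherwise 0.0. The qrels/results values are made up for the example.

qrels = {"q1": {"d1": 1}, "q2": {"d9": 1}}
results = {
    "q1": {"d1": 0.9, "d2": 0.5},  # relevant doc in top-2 -> 1.0
    "q2": {"d3": 0.8, "d4": 0.2},  # relevant doc missed   -> 0.0
}
k = 2
cv_recalls = {
    qid: 1.0 if any(doc in rel for doc in list(results.get(qid, {}))[:k]) else 0.0
    for qid, rel in qrels.items()
}
assert cv_recalls == {"q1": 1.0, "q2": 0.0}
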
2 changes: 2 additions & 0 deletions mteb/models/overview.py
Expand Up @@ -71,6 +71,7 @@
vlm2vec_models,
voyage_models,
voyage_v,
wav2vec_models,
)

logger = logging.getLogger(__name__)
@@ -136,6 +137,7 @@
uae_models,
voyage_models,
fa_models,
wav2vec_models,
]
MODEL_REGISTRY = {}
