84 changes: 59 additions & 25 deletions mteb/abstasks/Audio/AbsTaskAny2AnyRetrieval.py
@@ -8,10 +8,9 @@
from time import time
from typing import Any

import torch
import tqdm
from datasets import Features, Value, load_dataset
import torch
import torchaudio

from ...evaluation.evaluators import Any2AnyRetrievalEvaluator
from ..AbsTask import AbsTask, ScoresDict
@@ -43,9 +42,15 @@ def __init__(
if prefix:
query_file = prefix + "-" + query_file
qrels_folder = prefix + "-" + qrels_folder
self.corpus_file = os.path.join(data_folder, corpus_file) if data_folder else corpus_file
self.query_file = os.path.join(data_folder, query_file) if data_folder else query_file
self.qrels_folder = os.path.join(data_folder, qrels_folder) if data_folder else None
self.corpus_file = (
os.path.join(data_folder, corpus_file) if data_folder else corpus_file
)
self.query_file = (
os.path.join(data_folder, query_file) if data_folder else query_file
)
self.qrels_folder = (
os.path.join(data_folder, qrels_folder) if data_folder else None
)
self.qrels_file = qrels_file
self.streaming = streaming
self.keep_in_memory = keep_in_memory
@@ -57,7 +62,9 @@ def check(fIn: str, ext: str):
if not fIn.endswith(ext):
raise ValueError(f"File {fIn} must have extension {ext}")

def load(self, split="test") -> tuple[
def load(
self, split="test"
) -> tuple[
dict[str, dict[str, str | torch.Tensor]],
dict[str, dict[str, str | torch.Tensor]],
dict[str, dict[str, int]],
@@ -71,7 +78,9 @@ def load(self, split="test") -> tuple[
if not len(self.corpus):
logger.info("Loading Corpus...")
self._load_corpus()
logger.info("Loaded %d Documents for %s split.", len(self.corpus), split.upper())
logger.info(
"Loaded %d Documents for %s split.", len(self.corpus), split.upper()
)
logger.info("Doc Example: %s", self.corpus[0])

if not len(self.queries):
@@ -80,8 +89,10 @@

self._load_qrels(split)
qrels_dict = defaultdict(dict)

def qrels_dict_init(row):
qrels_dict[row["query-id"]][row["corpus-id"]] = int(row["score"])

self.qrels.map(qrels_dict_init)
self.qrels = qrels_dict
self.queries = self.queries.filter(lambda x: x["id"] in self.qrels)
@@ -150,18 +161,21 @@ def _load_qrels(self, split):
)
if "Q0" in qrels_ds.column_names:
qrels_ds = qrels_ds.remove_columns("Q0")
features = Features({
"query-id": Value("string"),
"corpus-id": Value("string"),
"score": Value("float"),
})
qrels_ds = qrels_ds.select_columns(["query-id", "corpus-id", "score"]).cast(features)
features = Features(
{
"query-id": Value("string"),
"corpus-id": Value("string"),
"score": Value("float"),
}
)
qrels_ds = qrels_ds.select_columns(["query-id", "corpus-id", "score"]).cast(
features
)
self.qrels = qrels_ds


class AbsTaskAny2AnyRetrieval(AbsTask):
"""
Abstract class for audio-text retrieval experiments.
"""Abstract class for audio-text retrieval experiments.

Child-classes must implement:
- self.corpus: dict[str, dict[str, str]]
@@ -190,7 +204,9 @@ def load_data(self, **kwargs):
keep_in_memory=False,
).load(split=split)
self.corpus[split], self.queries[split], self.relevant_docs[split] = (
corpus, queries, qrels
corpus,
queries,
qrels,
)
self.data_loaded = True

@@ -249,10 +265,16 @@ def _evaluate_subset(
if top_k is not None:
for qid in list(results.keys()):
doc_ids = set(
sorted(results[qid], key=lambda x: results[qid][x], reverse=True)[:top_k]
sorted(
results[qid], key=lambda x: results[qid][x], reverse=True
)[:top_k]
)
results[qid] = {k: v for k, v in results[qid].items() if k in doc_ids}
predictions_path = output_folder / f"{self.metadata.name}_{hf_subset}_predictions.json"
results[qid] = {
k: v for k, v in results[qid].items() if k in doc_ids
}
predictions_path = (
output_folder / f"{self.metadata.name}_{hf_subset}_predictions.json"
)
with open(predictions_path, "w") as f:
json.dump(results, f)

@@ -295,14 +317,20 @@ def _evaluate_subset(
results[qid] = dict(sorted_docs)
for qid, retrieved_docs in results.items():
expected_docs = relevant_docs[qid]
false_positives = [doc for doc in retrieved_docs if doc not in expected_docs]
false_negatives = [doc for doc in expected_docs if doc not in retrieved_docs]
false_positives = [
doc for doc in retrieved_docs if doc not in expected_docs
]
false_negatives = [
doc for doc in expected_docs if doc not in retrieved_docs
]
if false_positives or false_negatives:
errors[qid] = {
"false_positives": false_positives,
"false_negatives": false_negatives,
}
errors_path = output_folder / f"{self.metadata.name}_{hf_subset}_errors.json"
errors_path = (
output_folder / f"{self.metadata.name}_{hf_subset}_errors.json"
)
with open(errors_path, "w") as f:
json.dump(errors, f)

@@ -319,13 +347,17 @@ def _calculate_metrics_from_split(
def calculate_metadata_metrics(self) -> None:
self.load_data()
all_details = {}
pbar_split = tqdm.tqdm(self.metadata_dict["eval_splits"], desc="Processing Splits...")
pbar_split = tqdm.tqdm(
self.metadata_dict["eval_splits"], desc="Processing Splits..."
)
for split in pbar_split:
pbar_split.set_postfix_str(f"Split: {split}")
logger.info(f"Processing metadata for split {split}")
all_details[split] = {}
if self.is_multilingual:
pbar_lang = tqdm.tqdm(self.relevant_docs.keys(), desc="Processing Languages...")
pbar_lang = tqdm.tqdm(
self.relevant_docs.keys(), desc="Processing Languages..."
)
for lang in pbar_lang:
pbar_lang.set_postfix_str(f"Language: {lang}")
logger.info(f"Processing metadata for language {lang}")
@@ -358,7 +390,9 @@ def process_language(relevant_docs, queries, corpus, lang=None):
logger.info(f"Average query length{language_description} is {query_len}")
logger.info(f"Number of documents{language_description} is {num_documents}")
logger.info(f"Number of queries{language_description} is {num_queries}")
logger.info(f"Average relevant docs per query{language_description} is {qrels_per_doc}")
logger.info(
f"Average relevant docs per query{language_description} is {qrels_per_doc}"
)
return {
"average_document_length": doc_len,
"average_query_length": query_len,
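For reference, the structures a concrete task works with follow the return annotation of load() in the hunks above. A minimal sketch; the ids and field names ("audio", "text") are illustrative assumptions, not values from this PR:

import torch

# Shapes follow load()'s return annotation above; ids and field names
# ("audio", "text") are illustrative assumptions, not taken from the PR.
corpus: dict[str, dict[str, str | torch.Tensor]] = {
    "doc0": {"audio": torch.zeros(1, 16000)},  # 1 s of silence at 16 kHz
}
queries: dict[str, dict[str, str | torch.Tensor]] = {
    "q0": {"text": "a dog barking"},
}
qrels: dict[str, dict[str, int]] = {
    "q0": {"doc0": 1},  # query q0 considers doc0 relevant
}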
3 changes: 2 additions & 1 deletion mteb/abstasks/TaskMetadata.py
@@ -59,7 +59,7 @@
"Activity recognition",
"Tumor detection",
"Duplicate Detection",
"Voice Gender Clustering",
"Gender Clustering",
"Voice Emotion Clustering",
]

@@ -141,6 +141,7 @@
"i2it", # image-to-image+text
"t2it", # text-to-image+text
"it2it", # image+text-to-image+text
"a2a", # audio-to-audio
]

ANNOTATOR_TYPE = Literal[
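The new "a2a" code extends the same Literal of modality codes shown in the hunk above. A small sketch of how such a Literal can be checked at runtime; the alias name MODALITY_PAIR and the validator are assumptions, only the codes come from the diff:

from typing import Literal, get_args

# Alias name is hypothetical; the codes mirror the hunk above.
MODALITY_PAIR = Literal["i2it", "t2it", "it2it", "a2a"]

def validate_modality(code: str) -> str:
    # get_args() recovers the allowed strings from the Literal at runtime.
    if code not in get_args(MODALITY_PAIR):
        raise ValueError(f"unknown modality code: {code}")
    return code

validate_modality("a2a")  # ok: audio-to-audio, added by this PR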
28 changes: 22 additions & 6 deletions mteb/evaluation/evaluators/Audio/Any2AnyRetrievalEvaluator.py
@@ -12,8 +12,8 @@
import numpy as np
import pytrec_eval
import torch
from datasets import Dataset
import torchaudio
from datasets import Dataset
from torch.utils.data import DataLoader

from mteb.encoder_interface import Encoder, PromptType
@@ -33,8 +33,10 @@

logger = logging.getLogger(__name__)


# A default transform for audio; replace with a more meaningful transform as needed.
DEFAULT_AUDIO_TRANSFORM = lambda x: x
def DEFAULT_AUDIO_TRANSFORM(x):
return x


class AudioDataset(torch.utils.data.Dataset):
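As the comment above invites, a real task would swap the identity transform for something meaningful. One hedged possibility using torchaudio; the sample rates are assumptions, not values from this PR:

import torch
import torchaudio

# Hypothetical replacement for DEFAULT_AUDIO_TRANSFORM: resample to 16 kHz.
resample = torchaudio.transforms.Resample(orig_freq=44100, new_freq=16000)

waveform = torch.randn(1, 44100)   # one second of fake audio at 44.1 kHz
waveform_16k = resample(waveform)  # -> shape (1, 16000)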
@@ -342,7 +344,9 @@ def evaluate(
if qid == pid:
results[qid].pop(pid)
else:
logger.debug("Not ignoring identical query and document ids. Set ignore_identical_ids=True to ignore.")
logger.debug(
"Not ignoring identical query and document ids. Set ignore_identical_ids=True to ignore."
)

all_ndcgs, all_aps, all_recalls, all_precisions, all_cv_recalls = (
{},
@@ -404,7 +408,9 @@ def evaluate(
_map[f"MAP@{k}"] = round(sum(_map[f"MAP@{k}"]) / len(scores), 5)
recall[f"Recall@{k}"] = round(sum(recall[f"Recall@{k}"]) / len(scores), 5)
precision[f"P@{k}"] = round(sum(precision[f"P@{k}"]) / len(scores), 5)
cv_recall[f"CV_Recall@{k}"] = round(sum(cv_recall[f"CV_Recall@{k}"]) / len(scores), 5)
cv_recall[f"CV_Recall@{k}"] = round(
sum(cv_recall[f"CV_Recall@{k}"]) / len(scores), 5
)

naucs = Any2AnyRetrievalEvaluator.evaluate_abstention(
results,
@@ -427,7 +433,13 @@ def evaluate_custom(
metric_scores = recall_cap(qrels, results, k_values, output_type)
elif metric.lower() in ["hole", "hole@k"]:
metric_scores = hole(qrels, results, k_values, output_type)
elif metric.lower() in ["acc", "top_k_acc", "accuracy", "accuracy@k", "top_k_accuracy"]:
elif metric.lower() in [
"acc",
"top_k_acc",
"accuracy",
"accuracy@k",
"top_k_accuracy",
]:
metric_scores = top_k_accuracy(qrels, results, k_values, output_type)
naucs = Any2AnyRetrievalEvaluator.evaluate_abstention(results, metric_scores)
metric_scores_avg = {k: sum(v) / len(v) for k, v in metric_scores.items()}
@@ -460,5 +472,9 @@ def calculate_cv_style_recall(
cv_recalls = {}
for query_id, relevant_docs in qrels.items():
retrieved_docs = list(results.get(query_id, {}).keys())[:k]
cv_recalls[query_id] = 1.0 if any(doc_id in relevant_docs for doc_id in retrieved_docs) else 0.0
cv_recalls[query_id] = (
1.0
if any(doc_id in relevant_docs for doc_id in retrieved_docs)
else 0.0
)
return cv_recalls
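A toy walk-through of the CV-style recall just above: a query scores 1.0 if any of its top-k retrieved docs is relevant, else 0.0. The ids and scores below are made up, and results are assumed to be pre-sorted by score:

qrels = {"q0": {"d1": 1}, "q1": {"d9": 1}}
results = {
    "q0": {"d1": 0.9, "d2": 0.5},  # relevant d1 is in the top 2 -> 1.0
    "q1": {"d2": 0.8, "d3": 0.7},  # relevant d9 never retrieved -> 0.0
}

k = 2
cv_recalls = {
    qid: 1.0 if any(d in rel for d in list(results.get(qid, {}))[:k]) else 0.0
    for qid, rel in qrels.items()
}
assert cv_recalls == {"q0": 1.0, "q1": 0.0}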