269 changes: 269 additions & 0 deletions mteb/abstasks/Image/AbsTaskAny2AnyMultiChoice.py
@@ -656,3 +656,272 @@ def process_docs(
return {
f"{split}_{hf_subset}_{k}": v for k, v in collection[hf_subset][split].items()
}


class MultiChoiceEvaluationMixin:
[Collaborator comment] I'm a fan of this solution!

"""A mixin class to enable retrieval tasks to use multiple-choice evaluator;
It is designed for tasks like r-Oxford and r-Pairs that
require masking out different documents in the corpus for each query.

example usage:
class ROxfordHardI2IRetrieval(MultiChoiceEvaluationMixin, AbsTaskAny2AnyRetrieval):

It is for overriding `def evaluate`, `def _evaluate_subset`
and `def _calculate_metrics_from_split` of AbsTaskAny2AnyRetrieval.
"""

def evaluate(
self,
model,
split: str = "test",
*,
        encode_kwargs: dict[str, Any] | None = None,
**kwargs,
):
# Use Any2AnyMultiChoiceEvaluator instead of Any2AnyRetrievalEvaluator
evaluator = Any2AnyMultiChoiceEvaluator(
retriever=model,
task_name=self.metadata.name,
encode_kwargs=encode_kwargs if encode_kwargs is not None else {},
**kwargs,
)

scores = {}
hf_subsets = list(self.hf_subsets) if self.is_multilingual else ["default"]

for hf_subset in hf_subsets:
logger.info(f"Subset: {hf_subset}")

if hf_subset == "default":
corpus, queries, relevant_docs = (
self.corpus[split],
self.queries[split],
self.relevant_docs[split],
)
else:
corpus, queries, relevant_docs = (
self.corpus[hf_subset][split],
self.queries[hf_subset][split],
self.relevant_docs[hf_subset][split],
)
scores[hf_subset] = self._evaluate_subset(
evaluator, corpus, queries, relevant_docs, hf_subset, **kwargs
)
return scores

def _evaluate_subset(
self, retriever, corpus, queries, relevant_docs, hf_subset: str, **kwargs
):
start_time = time()
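        # Unlike plain retrieval, the evaluator also receives the qrels so it
        # can mask out the per-query excluded documents described in the docstring.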
results = retriever(corpus, queries, relevant_docs)
end_time = time()
logger.info(f"Time taken to retrieve: {end_time - start_time:.2f} seconds")

save_predictions = kwargs.get("save_predictions", False)
export_errors = kwargs.get("export_errors", False)
if save_predictions or export_errors:
output_folder = Path(kwargs.get("output_folder", "results"))
if not os.path.isdir(output_folder):
os.makedirs(output_folder)

if save_predictions:
top_k = kwargs.get("top_k", None)
if top_k is not None:
for qid in list(results.keys()):
doc_ids = set(
sorted(
results[qid], key=lambda x: results[qid][x], reverse=True
)[:top_k]
)
results[qid] = {
k: v for k, v in results[qid].items() if k in doc_ids
}
qrels_save_path = (
output_folder / f"{self.metadata.name}_{hf_subset}_predictions.json"
)

with open(qrels_save_path, "w") as f:
json.dump(results, f)

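        # Score the rankings; `ignore_identical_ids` and `skip_first_result`
        # control how a query that also appears in the corpus is handled.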
ndcg, _map, recall, precision, cv_recall, naucs = retriever.evaluate(
relevant_docs,
results,
retriever.k_values,
ignore_identical_ids=self.ignore_identical_ids,
skip_first_result=self.skip_first_result,
)
mrr, naucs_mrr = retriever.evaluate_custom(
relevant_docs, results, retriever.k_values, "mrr"
)
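        # Normalize metric keys ("NDCG@10" -> "ndcg_at_10"); Recall@1 is also
        # reported as multiple-choice accuracy.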
scores = {
**{f"ndcg_at_{k.split('@')[1]}": v for (k, v) in ndcg.items()},
**{f"map_at_{k.split('@')[1]}": v for (k, v) in _map.items()},
**{f"recall_at_{k.split('@')[1]}": v for (k, v) in recall.items()},
**{f"cv_recall_at_{k.split('@')[1]}": v for (k, v) in cv_recall.items()},
**{f"precision_at_{k.split('@')[1]}": v for (k, v) in precision.items()},
**{f"mrr_at_{k.split('@')[1]}": v for (k, v) in mrr.items()},
**{
k.replace("@", "_at_").replace("_P", "_precision").lower(): v
for k, v in naucs.items()
},
**{
k.replace("@", "_at_").replace("_P", "_precision").lower(): v
for k, v in naucs_mrr.items()
},
"accuracy": recall["Recall@1"],
}
self._add_main_score(scores)

if export_errors:
errors = {}

top_k = kwargs.get("top_k", 1)
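            # With the default top_k of 1, keep only the top-ranked document per
            # query (skipped if save_predictions already truncated the results).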
if not save_predictions and top_k == 1:
for qid in results.keys():
doc_scores = results[qid]
sorted_docs = sorted(
doc_scores.items(), key=lambda x: x[1], reverse=True
)[:top_k]
results[qid] = dict(sorted_docs)
for qid, retrieved_docs in results.items():
expected_docs = relevant_docs[qid]
false_positives = [
doc for doc in retrieved_docs if doc not in expected_docs
]
false_negatives = [
doc for doc in expected_docs if doc not in retrieved_docs
]
if false_positives or false_negatives:
errors[qid] = {
"false_positives": false_positives,
"false_negatives": false_negatives,
}

errors_save_path = (
output_folder / f"{self.metadata.name}_{hf_subset}_errors.json"
)
with open(errors_save_path, "w") as f:
json.dump(errors, f)

return scores

def _calculate_metrics_from_split(
self, split: str, hf_subset: str | None = None, compute_overall: bool = False
) -> Any2AnyMutipleChoiceDescriptiveStatistics:
if hf_subset:
queries = self.queries[hf_subset][split]
corpus = self.corpus[hf_subset][split]
relevant_docs = self.relevant_docs[hf_subset][split]
elif compute_overall:
queries = {}
corpus = {}
relevant_docs = {}
for hf_subset in self.metadata.eval_langs:
queries.update(process_docs(self.queries, hf_subset, split))
corpus.update(process_docs(self.corpus, hf_subset, split))
relevant_docs.update(
process_relevant_docs(self.relevant_docs, hf_subset, split)
)
else:
queries = self.queries[split]
corpus = self.corpus[split]
relevant_docs = self.relevant_docs[split]

queries_lens, doc_lens = [], []
num_query_images = 0
num_document_images = 0

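        # `modality` is a string (e.g. "text", "image", or a combined form),
        # so substring checks also handle mixed-modality records.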
q_modality = queries[0]["modality"]
unique_queries = len(set(queries["text"])) if "text" in q_modality else 0

for query in tqdm.tqdm(queries, desc="queries:"):
if "text" in q_modality:
text_query = query["text"]
queries_lens.append(len(text_query))
if "image" in q_modality:
num_query_images += 1

d_modality = corpus[0]["modality"]
unique_documents = len(set(corpus["text"])) if "text" in d_modality else 0

for doc in tqdm.tqdm(corpus, desc="docs:"):
if "text" in d_modality:
text_doc = doc["text"]
doc_lens.append(len(text_doc))
if "image" in d_modality:
num_document_images += 1

total_doc_len = sum(doc_lens)
total_query_len = sum(queries_lens)
num_documents = len(corpus)
num_queries = len(queries)

        # PIL's Image.size is (width, height).
        imgs = [doc["image"] for doc in corpus] if "image" in d_modality else []
        d_img_widths, d_img_heights = [], []
        for img in imgs:
            width, height = img.size
            d_img_widths.append(width)
            d_img_heights.append(height)

        imgs = [query["image"] for query in queries] if "image" in q_modality else []
        q_img_widths, q_img_heights = [], []
        for img in imgs:
            width, height = img.size
            q_img_widths.append(width)
            q_img_heights.append(height)

# create a list of number of relevant docs per query
queries_set = set(queries["id"])
qrels_lengths = [
len([v for k, v in relevant_docs[qid].items() if v != 0])
for qid in tqdm.tqdm(relevant_docs.keys(), desc="qrels:")
if qid in queries_set
]
num_qrels = sum(qrels_lengths)
        qrels_per_query = num_qrels / len(relevant_docs) if relevant_docs else 0
unique_qrels = len({doc for qid in relevant_docs for doc in relevant_docs[qid]})

return Any2AnyMutipleChoiceDescriptiveStatistics(
number_of_characters=total_query_len + total_doc_len,
num_samples=num_documents + num_queries,
num_queries=num_queries,
num_documents=num_documents,
min_document_length=min(doc_lens) if doc_lens else 0,
average_document_length=total_doc_len / len(doc_lens) if doc_lens else 0,
max_document_length=max(doc_lens) if doc_lens else 0,
unique_documents=unique_documents,
min_document_image_width=min(d_img_widths) if d_img_widths else 0,
average_document_image_width=sum(d_img_widths) / len(d_img_widths)
if d_img_widths
else 0,
max_document_image_width=max(d_img_widths) if d_img_widths else 0,
min_document_image_height=min(d_img_heights) if d_img_heights else 0,
average_document_image_height=sum(d_img_heights) / len(d_img_heights)
if d_img_heights
else 0,
max_document_image_height=max(d_img_heights) if d_img_heights else 0,
num_document_images=num_document_images,
min_query_length=min(queries_lens) if queries_lens else 0,
average_query_length=total_query_len / len(queries_lens)
if queries_lens
else 0,
max_query_length=max(queries_lens) if queries_lens else 0,
unique_queries=unique_queries,
num_query_images=num_query_images,
min_query_image_width=min(q_img_widths) if q_img_widths else 0,
average_query_image_width=sum(q_img_widths) / len(q_img_widths)
if q_img_widths
else 0,
max_query_image_width=max(q_img_widths) if q_img_widths else 0,
min_query_image_height=min(q_img_heights) if q_img_heights else 0,
average_query_image_height=sum(q_img_heights) / len(q_img_heights)
if q_img_heights
else 0,
max_query_image_height=max(q_img_heights) if q_img_heights else 0,
            min_relevant_docs_per_query=min(qrels_lengths) if qrels_lengths else 0,
            average_relevant_docs_per_query=qrels_per_query,
            max_relevant_docs_per_query=max(qrels_lengths) if qrels_lengths else 0,
unique_relevant_docs=unique_qrels,
)
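
The mixin relies only on Python's method resolution order: listed before AbsTaskAny2AnyRetrieval in a task's base classes, its methods shadow the retrieval versions. A minimal, self-contained sketch of the pattern, with illustrative stand-in bodies rather than the real mteb implementations:

class AbsTaskAny2AnyRetrieval:
    def evaluate(self, model, split="test", **kwargs):
        return {"evaluator": "Any2AnyRetrievalEvaluator"}


class MultiChoiceEvaluationMixin:
    def evaluate(self, model, split="test", **kwargs):
        # First in the MRO, so this override takes precedence.
        return {"evaluator": "Any2AnyMultiChoiceEvaluator"}


class ROxfordHardI2IRetrieval(MultiChoiceEvaluationMixin, AbsTaskAny2AnyRetrieval):
    pass


assert ROxfordHardI2IRetrieval().evaluate(model=None) == {
    "evaluator": "Any2AnyMultiChoiceEvaluator"
}

Task classes therefore switch evaluators simply by putting the mixin first in their base-class list; no method bodies change.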
26 changes: 12 additions & 14 deletions mteb/benchmarks/benchmarks.py
@@ -1465,14 +1465,13 @@
"STL10ZeroShot",
"SUN397ZeroShot",
"UCF101ZeroShot",
-    # Any2TextMutipleChoice
+    # Any2AnyMultipleChoice
+    "BLINKIT2IMultiChoice",
+    "BLINKIT2TMultiChoice",
"CVBenchCount",
"CVBenchRelation",
"CVBenchDepth",
"CVBenchDistance",
-    # Any2AnyMultipleChoice
-    "BLINKIT2IMultiChoice",
-    "BLINKIT2TMultiChoice",
# Compositionality
"ImageCoDeT2IMultiChoice",
"AROCocoOrder",
@@ -1514,13 +1513,13 @@
"NIGHTSI2IRetrieval",
"OVENIT2ITRetrieval",
"OVENIT2TRetrieval",
"ROxfordEasyI2IMultiChoice",
"ROxfordMediumI2IMultiChoice",
"ROxfordHardI2IMultiChoice",
"ROxfordEasyI2IRetrieval",
"ROxfordMediumI2IRetrieval",
"ROxfordHardI2IRetrieval",
"RP2kI2IRetrieval",
"RParisEasyI2IMultiChoice",
"RParisMediumI2IMultiChoice",
"RParisHardI2IMultiChoice",
"RParisEasyI2IRetrieval",
"RParisMediumI2IRetrieval",
"RParisHardI2IRetrieval",
"SciMMIRI2TRetrieval",
"SciMMIRT2IRetrieval",
"SketchyI2IRetrieval",
@@ -1609,14 +1608,13 @@
"Food101ZeroShot",
"OxfordPetsZeroShot",
"StanfordCarsZeroShot",
-    # Any2TextMutipleChoice
+    # Any2AnyMultipleChoice
+    "BLINKIT2IMultiChoice",
+    "ImageCoDeT2IMultiChoice",
"CVBenchCount",
"CVBenchRelation",
"CVBenchDepth",
"CVBenchDistance",
-    # Any2AnyMultipleChoice
-    "BLINKIT2IMultiChoice",
-    "ImageCoDeT2IMultiChoice",
# ImageTextPairClassification
"AROCocoOrder",
"AROFlickrOrder",

This file was deleted.
