268 changes: 232 additions & 36 deletions mteb/abstasks/Image/AbsTaskAny2AnyMultiChoice.py
@@ -15,6 +15,7 @@
from mteb.abstasks.AbsTask import AbsTask, ScoresDict

from ...evaluation.evaluators import Any2AnyMultiChoiceEvaluator
from ..TaskMetadata import DescriptiveStatistics

logger = logging.getLogger(__name__)

@@ -186,6 +187,95 @@ def _load_qrels(self, split):
self.qrels = qrels_ds


class Any2AnyMultipleChoiceDescriptiveStatistics(DescriptiveStatistics):
"""Descriptive statistics for Any2AnyMultiChoice

Attributes:
num_samples: Number of queries and documents
num_queries: number of queries in the dataset
num_documents: Number of documents
number_of_characters: Total number of text characters in the dataset

For text documents:
min_document_length: Minimum length of documents
average_document_length: Average length of documents
max_document_length: Maximum length of documents
unique_documents: Number of unique documents

For text queries:
min_query_length: Minimum length of queries
average_query_length: Average length of queries
max_query_length: Maximum length of queries
unique_queries: Number of unique queries

For images:
num_query_images: Number of query images
num_document_images: Number of document images

For document images:
min_document_image_width: Minimum width of document images
average_document_image_width: Average width of document images
max_document_image_width: Maximum width of document images
min_document_image_height: Minimum height of document images
average_document_image_height: Average height of document images
max_document_image_height: Maximum height of document images

For query images:
min_query_image_width: Minimum width of query images
average_query_image_width: Average width of query images
max_query_image_width: Maximum width of query images
min_query_image_height: Minimum height of query images
average_query_image_height: Average height of query images
max_query_image_height: Maximum height of query images

min_relevant_docs_per_query: Minimum number of relevant documents per query
average_relevant_docs_per_query: Average number of relevant documents per query
max_relevant_docs_per_query: Maximum number of relevant documents per query
unique_relevant_docs: Number of unique relevant documents

min_irrelevant_docs_per_query: Minimum number of irrelevant documents per query
average_irrelevant_docs_per_query: Average number of irrelevant documents per query
max_irrelevant_docs_per_query: Maximum number of irrelevant documents per query
unique_irrelevant_docs: Number of unique irrelevant documents
"""

num_samples: int
num_queries: int
num_documents: int
number_of_characters: int

min_document_length: int
average_document_length: float
max_document_length: int
unique_documents: int
num_document_images: int

min_document_image_width: float
average_document_image_width: float
max_document_image_width: float
min_document_image_height: float
average_document_image_height: float
max_document_image_height: float

min_query_length: int
average_query_length: float
max_query_length: int
unique_queries: int
num_query_images: int

min_query_image_width: float
average_query_image_width: float
max_query_image_width: float
min_query_image_height: float
average_query_image_height: float
max_query_image_height: float

min_relevant_docs_per_query: int
average_relevant_docs_per_query: float
max_relevant_docs_per_query: int
unique_relevant_docs: int


class AbsTaskAny2AnyMultiChoice(AbsTask):
"""Abstract class for Any2Any multiple choice experiments

@@ -376,41 +466,124 @@ def _add_main_score(self, scores: ScoresDict) -> None:

def _calculate_metrics_from_split(
self, split: str, hf_subset: str | None = None, compute_overall: bool = False
):
pass

def calculate_metadata_metrics(self) -> None:
self.load_data()

all_details = {}
pbar_split = tqdm.tqdm(
self.metadata_dict["eval_splits"], desc="Processing Splits..."
)
for split in pbar_split:
pbar_split.set_postfix_str(f"Split: {split}")
logger.info(f"Processing metadata for split {split}")
all_details[split] = {}
if self.is_multilingual:
pbar_lang = tqdm.tqdm(
self.relevant_docs.keys(), desc="Processing Languages..."
)
for lang in pbar_lang:
pbar_lang.set_postfix_str(f"Language: {lang}")
logger.info(f"Processing metadata for language {lang}")
split_details = process_language(
self.relevant_docs[lang][split],
self.queries[lang][split],
self.corpus[lang][split],
lang,
)
all_details[split][lang] = split_details
else:
split_details = process_language(
self.relevant_docs[split], self.queries[split], self.corpus[split]
)
all_details[split] = split_details

return all_details
) -> Any2AnyMultipleChoiceDescriptiveStatistics:
if hf_subset:
queries = self.queries[hf_subset][split]
corpus = self.corpus[hf_subset][split]
relevant_docs = self.relevant_docs[hf_subset][split]
elif compute_overall:
queries = {}
corpus = {}
relevant_docs = {}
for hf_subset in self.metadata.eval_langs:
queries.update(process_docs(self.queries, hf_subset, split))
corpus.update(process_docs(self.corpus, hf_subset, split))
relevant_docs.update(
process_relevant_docs(self.relevant_docs, hf_subset, split)
)
else:
queries = self.queries[split]
corpus = self.corpus[split]
relevant_docs = self.relevant_docs[split]

queries_lens, doc_lens = [], []
num_query_images = 0
num_document_images = 0

q_modality = queries[0]["modality"]
unique_queries = len(set(queries["text"])) if "text" in q_modality else 0

for query in tqdm.tqdm(queries, desc="queries:"):
if "text" in q_modality:
text_query = query["text"]
queries_lens.append(len(text_query))
if "image" in q_modality:
num_query_images += 1

d_modality = corpus[0]["modality"]
unique_documents = len(set(corpus["text"])) if "text" in d_modality else 0

for doc in tqdm.tqdm(corpus, desc="docs:"):
if "text" in d_modality:
text_doc = doc["text"]
doc_lens.append(len(text_doc))
if "image" in d_modality:
num_document_images += 1

total_doc_len = sum(doc_lens)
total_query_len = sum(queries_lens)
num_documents = len(corpus)
num_queries = len(queries)

d_modality = corpus[0]["modality"]
imgs = [doc["image"] for doc in corpus if "image" in d_modality]
d_img_widths, d_img_heights = [], []
for img in imgs:
width, height = img.size
d_img_widths.append(width)
d_img_heights.append(height)

q_modality = queries[0]["modality"]
imgs = [query["image"] for query in queries if "image" in q_modality]
q_img_widths, q_img_heights = [], []
for img in imgs:
width, height = img.size
q_img_widths.append(width)
q_img_heights.append(height)

# create a list of number of relevant docs per query
queries_set = set(queries["id"])
qrels_lengths = [
len(relevant_docs[qid])
for qid in tqdm.tqdm(relevant_docs.keys(), desc="qrels:")
if qid in queries_set
]
num_qrels = sum(qrels_lengths)
qrels_per_doc = num_qrels / len(relevant_docs) if relevant_docs else 0
unique_qrels = len({doc for qid in relevant_docs for doc in relevant_docs[qid]})

return Any2AnyMultipleChoiceDescriptiveStatistics(
number_of_characters=total_query_len + total_doc_len,
num_samples=num_documents + num_queries,
num_queries=num_queries,
num_documents=num_documents,
min_document_length=min(doc_lens) if doc_lens else 0,
average_document_length=total_doc_len / len(doc_lens) if doc_lens else 0,
max_document_length=max(doc_lens) if doc_lens else 0,
unique_documents=unique_documents,
min_document_image_width=min(d_img_widths) if d_img_widths else 0,
average_document_image_width=sum(d_img_widths) / len(d_img_widths)
if d_img_widths
else 0,
max_document_image_width=max(d_img_widths) if d_img_widths else 0,
min_document_image_height=min(d_img_heights) if d_img_heights else 0,
average_document_image_height=sum(d_img_heights) / len(d_img_heights)
if d_img_heights
else 0,
max_document_image_height=max(d_img_heights) if d_img_heights else 0,
num_document_images=num_document_images,
min_query_length=min(queries_lens) if queries_lens else 0,
average_query_length=total_query_len / len(queries_lens)
if queries_lens
else 0,
max_query_length=max(queries_lens) if queries_lens else 0,
unique_queries=unique_queries,
num_query_images=num_query_images,
min_query_image_width=min(q_img_widths) if q_img_widths else 0,
average_query_image_width=sum(q_img_widths) / len(q_img_widths)
if q_img_widths
else 0,
max_query_image_width=max(q_img_widths) if q_img_widths else 0,
min_query_image_height=min(q_img_heights) if q_img_heights else 0,
average_query_image_height=sum(q_img_heights) / len(q_img_heights)
if q_img_heights
else 0,
max_query_image_height=max(q_img_heights) if q_img_heights else 0,
min_relevant_docs_per_query=min(qrels_lengths) if qrels_lengths else 0,
average_relevant_docs_per_query=qrels_per_doc,
max_relevant_docs_per_query=max(qrels_lengths) if qrels_lengths else 0,
unique_relevant_docs=unique_qrels,
)


def process_language(relevant_docs, queries, corpus, lang=None):
Expand Down Expand Up @@ -450,13 +623,36 @@ def process_language(relevant_docs, queries, corpus, lang=None):
def calculate_length(queries, corpus):
queries_lens = []
doc_lens = []
for query in queries.values():
for query in queries:
queries_lens.append(len(query))

for doc in corpus.values():
for doc in corpus:
if isinstance(doc, Image.Image):
doc_lens.append(1.0) # for image append 1. Can perhaps be removed.

doc_len = sum(doc_lens) / len(doc_lens) if doc_lens else 0
query_len = sum(queries_lens) / len(queries_lens) if queries_lens else 0
return query_len, doc_len


def process_relevant_docs(
collection: dict[str, dict[str, dict[str, dict[str, int]]]],
hf_subset: str,
split: str,
) -> dict[str, dict[str, int]]:
"""Collections can contain overlapping ids in different splits. Prepend split to avoid this"""
return_collection = {}
for query_id, relevant in collection[hf_subset][split].items():
return_collection[f"{split}_{hf_subset}_{query_id}"] = {
f"{split}_{hf_subset}_{doc_id}": value for doc_id, value in relevant.items()
}
return return_collection


def process_docs(
collection: dict[str, dict[str, dict[str, str] | str]], hf_subset: str, split: str
) -> dict[str, str]:
"""Collections can contain overlapping ids in different splits. Prepend split to avoid this"""
return {
f"{split}_{hf_subset}_{k}": v for k, v in collection[hf_subset][split].items()
}
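
As a quick illustration of the id prefixing performed by process_docs and process_relevant_docs above, here is a minimal sketch (not part of the PR; the import path is assumed from the file location shown in this diff):

# Illustrative sketch only; assumes the two helpers are importable from the module above.
from mteb.abstasks.Image.AbsTaskAny2AnyMultiChoice import process_docs, process_relevant_docs

# Toy collections keyed by subset, then split, then id.
corpus = {"eng": {"test": {"d1": "a photo of a cat", "d2": "a photo of a dog"}}}
qrels = {"eng": {"test": {"q1": {"d1": 1}}}}

print(process_docs(corpus, "eng", "test"))
# -> {'test_eng_d1': 'a photo of a cat', 'test_eng_d2': 'a photo of a dog'}
print(process_relevant_docs(qrels, "eng", "test"))
# -> {'test_eng_q1': {'test_eng_d1': 1}}
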
62 changes: 60 additions & 2 deletions mteb/abstasks/Image/AbsTaskImageTextPairClassification.py
@@ -8,10 +8,35 @@
from ...encoder_interface import Encoder
from ...evaluation.evaluators import ImageTextPairClassificationEvaluator
from ..AbsTask import AbsTask, ScoresDict
from ..TaskMetadata import DescriptiveStatistics

logger = logging.getLogger(__name__)


class ImageTextPairClassificationDescriptiveStatistics(DescriptiveStatistics):
"""Descriptive statistics for ImageTextPairClassification

Attributes:
num_samples: number of samples in the dataset.
num_images: number of images in the dataset.
num_texts: number of texts in the dataset.
num_unique_texts: number of unique texts in the dataset.

min_text_length: Minimum length of texts
average_text_length: Average length of texts
max_text_length: Maximum length of texts
"""

num_samples: int
num_images: int
num_texts: int
num_unique_texts: int

min_text_length: int
average_text_length: float
max_text_length: int


class AbsTaskImageTextPairClassification(AbsTask):
"""Abstract class for Image Text Pair Classification tasks,
e.g. Compositionality evaluation.
@@ -35,8 +60,41 @@ def _add_main_score(self, scores) -> None:

def _calculate_metrics_from_split(
self, split: str, hf_subset: str | None = None, compute_overall: bool = False
):
pass
) -> ImageTextPairClassificationDescriptiveStatistics:
dataset = (
self.dataset[split] if hf_subset is None else self.dataset[hf_subset][split]
)
num_samples = len(dataset)

if isinstance(self.images_column_names, str):
num_images = len(dataset[self.images_column_names])
elif isinstance(self.images_column_names, list):
num_images = sum(
[len(dataset[img_column]) for img_column in self.images_column_names]
)

if isinstance(self.texts_column_names, str):
texts = list(dataset[self.texts_column_names])
unique_texts = set(texts)
text_lengths = [len(text) for text in texts]
elif isinstance(self.texts_column_names, list):
texts = [
text
for text_column in self.texts_column_names
for text in dataset[text_column]
]
unique_texts = set(texts)
text_lengths = [len(text) for text in texts]

return ImageTextPairClassificationDescriptiveStatistics(
num_samples=num_samples,
num_images=num_images,
num_texts=len(texts),
num_unique_texts=len(unique_texts),
min_text_length=min(text_lengths),
average_text_length=sum(text_lengths) / len(text_lengths),
max_text_length=max(text_lengths),
)

def _evaluate_subset(
self,
11 changes: 11 additions & 0 deletions mteb/descriptive_stats/Image/Compositionality/SugarCrepe.json
@@ -0,0 +1,11 @@
{
"test": {
"num_samples": 7511,
"num_images": 7511,
"num_texts": 15022,
"num_unique_texts": 11844,
"min_text_length": 24,
"average_text_length": 56.48681933164692,
"max_text_length": 210
}
}
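
As a sanity check on cached statistics files like the one above, a small illustrative snippet (not part of the PR; the file path and the use of TypedDict introspection are assumptions) that verifies each split carries every field declared on ImageTextPairClassificationDescriptiveStatistics:

# Illustrative sketch only; path and class location are assumed from this diff.
import json

from mteb.abstasks.Image.AbsTaskImageTextPairClassification import (
    ImageTextPairClassificationDescriptiveStatistics,
)

with open("mteb/descriptive_stats/Image/Compositionality/SugarCrepe.json") as f:
    stats = json.load(f)

expected = set(ImageTextPairClassificationDescriptiveStatistics.__annotations__)
for split, values in stats.items():
    missing = expected - set(values)
    print(split, "missing fields:", sorted(missing) if missing else "none")
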