Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 83 additions & 2 deletions mteb/abstasks/Image/AbsTaskAny2TextMultipleChoice.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,65 @@
from __future__ import annotations

import logging
from collections import Counter
from typing import Any

from datasets import Dataset

from ...encoder_interface import Encoder
from ...evaluation.evaluators import Any2TextMultipleChoiceEvaluator
from ..AbsTask import AbsTask, ScoresDict
from ..TaskMetadata import DescriptiveStatistics

logger = logging.getLogger(__name__)


class Any2TextMutipleChoiceDescriptiveStatistics(DescriptiveStatistics):
"""Descriptive statistics for Any2TextMutipleChoice

Attributes:
num_samples: number of samples in the dataset.

min_image_width: Minimum width of images
average_image_width: Average width of images
max_image_width: Maximum width of images

min_image_height: Minimum height of images
average_image_height: Average height of images
max_image_height: Maximum height of images

min_num_choices: Minimum number of choices
average_num_choices: Average number of choices
max_num_choices: Maximum number of choices

answers: dict of answer frequencies

min_question_length: Minimum length of questions
average_question_length: Average length of questions
max_question_length: Maximum length of questions
"""

num_samples: int

min_image_width: float
average_image_width: float
max_image_width: float

min_image_height: float
average_image_height: float
max_image_height: float

min_num_choices: int
average_num_choices: float
max_num_choices: int

answers: dict[str, dict[str, int]]

min_question_length: int
average_question_length: float
max_question_length: int


class AbsTaskAny2TextMultipleChoice(AbsTask):
"""Abstract class for Any to Text Multiple Choice tasks,
where the queries and be either text or image, or both.
Expand All @@ -34,8 +82,41 @@ def _add_main_score(self, scores) -> None:

def _calculate_metrics_from_split(
self, split: str, hf_subset: str | None = None, compute_overall: bool = False
):
pass
) -> Any2TextMutipleChoiceDescriptiveStatistics:
imgs = self.dataset[split][self.query_column_names["image"]]
questions = self.dataset[split][self.query_column_names["text"]]
choices = self.dataset[split][self.choices_column_name]
answers = self.dataset[split][self.label_column_name]
Comment on lines +86 to +89
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For now, there are no multilingual datasets?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's right. Eng only.


num_samples = len(answers)
answer_count = Counter(answers)
img_widths, img_heights = [], []
for img in imgs:
width, height = img.size
img_heights.append(height)
img_widths.append(width)

choices_len = [len(c) for c in choices]
questions_len = [len(q) for q in questions]

return Any2TextMutipleChoiceDescriptiveStatistics(
num_samples=num_samples,
min_image_width=min(img_widths),
average_image_width=sum(img_widths) / len(img_widths),
max_image_width=max(img_widths),
min_image_height=min(img_heights),
average_image_height=sum(img_heights) / len(img_heights),
max_image_height=max(img_heights),
min_num_choices=min(choices_len),
average_num_choices=sum(choices_len) / len(choices_len),
max_num_choices=max(choices_len),
min_question_length=min(questions_len),
average_question_length=sum(questions_len) / len(questions_len),
max_question_length=max(questions_len),
answers={
str(answer): {"count": count} for answer, count in answer_count.items()
},
)

def _evaluate_subset(
self,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
{
"test": {
"num_samples": 788,
"min_image_width": 200,
"average_image_width": 757.6789340101523,
"max_image_width": 2200,
"min_image_height": 181,
"average_image_height": 631.3147208121827,
"max_image_height": 2200,
"min_num_choices": 4,
"average_num_choices": 4.550761421319797,
"max_num_choices": 6,
"min_question_length": 30,
"average_question_length": 34.35406091370558,
"max_question_length": 45,
"answers": {
"2": {
"count": 169
},
"4": {
"count": 63
},
"3": {
"count": 167
},
"1": {
"count": 184
},
"0": {
"count": 182
},
"5": {
"count": 23
}
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
{
"test": {
"num_samples": 600,
"min_image_width": 561,
"average_image_width": 1090.9616666666666,
"max_image_width": 1600,
"min_image_height": 427,
"average_image_height": 715.985,
"max_image_height": 900,
"min_num_choices": 2,
"average_num_choices": 2.0,
"max_num_choices": 2,
"min_question_length": 130,
"average_question_length": 136.04333333333332,
"max_question_length": 147,
"answers": {
"0": {
"count": 300
},
"1": {
"count": 300
}
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
{
"test": {
"num_samples": 600,
"min_image_width": 561,
"average_image_width": 1099.2883333333334,
"max_image_width": 1600,
"min_image_height": 427,
"average_image_height": 720.9983333333333,
"max_image_height": 900,
"min_num_choices": 2,
"average_num_choices": 2.0,
"max_num_choices": 2,
"min_question_length": 204,
"average_question_length": 212.40333333333334,
"max_question_length": 223,
"answers": {
"0": {
"count": 303
},
"1": {
"count": 297
}
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
{
"test": {
"num_samples": 650,
"min_image_width": 189,
"average_image_width": 546.3169230769231,
"max_image_width": 2200,
"min_image_height": 190,
"average_image_height": 448.4492307692308,
"max_image_height": 2200,
"min_num_choices": 2,
"average_num_choices": 2.0,
"max_num_choices": 2,
"min_question_length": 132,
"average_question_length": 181.45846153846153,
"max_question_length": 224,
"answers": {
"0": {
"count": 327
},
"1": {
"count": 323
}
}
}
}