Commit c20197f

Merge 0cdbaca into 7d850df
wangbo-zhao authored Jun 27, 2023
2 parents 7d850df + 0cdbaca
Showing 4 changed files with 249 additions and 2 deletions.
3 changes: 2 additions & 1 deletion mmpretrain/datasets/__init__.py
@@ -34,6 +34,7 @@
]

if WITH_MULTIMODAL:
from .chartqa import ChartQA
from .coco_caption import COCOCaption
from .coco_retrieval import COCORetrieval
from .coco_vqa import COCOVQA
@@ -54,5 +55,5 @@
'COCOCaption', 'COCORetrieval', 'COCOVQA', 'FlamingoEvalCOCOCaption',
'FlamingoEvalCOCOVQA', 'Flickr30kCaption', 'Flickr30kRetrieval',
'RefCOCO', 'VisualGenomeQA', 'ScienceQA', 'NoCaps', 'GQA', 'TextVQA',
'VSR', 'VizWiz', 'OCRVQA'
'VSR', 'VizWiz', 'OCRVQA', 'ChartQA'
])
115 changes: 115 additions & 0 deletions mmpretrain/datasets/chartqa.py
@@ -0,0 +1,115 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from typing import List

import mmengine
from mmengine.dataset import BaseDataset
from mmengine.utils import is_abs

from mmpretrain.registry import DATASETS


@DATASETS.register_module()
class ChartQA(BaseDataset):
"""ChartQA dataset.
dataset:https://github.com/vis-nlp/ChartQA
folder structure:
data/chartqa
├── test
│ ├── png
│ ├── tables
│ ├── test_human.json
│ └── test_augmented.json
├── train
│ ├── png
│ ├── tables
│ ├── train_human.json
│ └── train_augmented.json
└── val
├── png
├── tables
├── val_human.json
└── val_augmented.json
Args:
data_root (str): The root directory for ``data_prefix``, ``ann_file``
and ``question_file``.
data_prefix (str): The directory of images.
question_file (str): Question file path.
ann_file (str, optional): Annotation file path for training and
validation. Defaults to an empty string.
**kwargs: Other keyword arguments in :class:`BaseDataset`.
"""

    def __init__(self,
                 data_root: str,
                 data_prefix: str,
                 ann_file: List[str] = '',
                 **kwargs):
        super().__init__(
            data_root=data_root,
            data_prefix=dict(img_path=data_prefix),
            ann_file=ann_file,
            **kwargs,
        )

def _join_prefix(self):
        # Automatically join each annotation file path with `self.data_root`
        # if it is not an absolute path.
if not any(is_abs(sub_ann_file)
for sub_ann_file in self.ann_file) and self.ann_file:
self.ann_file = [
osp.join(self.data_root, sub_ann_file)
for sub_ann_file in self.ann_file
]
        # Automatically join the data directory with `self.data_root` if the
        # path value in `self.data_prefix` is not an absolute path.
for data_key, prefix in self.data_prefix.items():
if isinstance(prefix, str):
if not is_abs(prefix):
self.data_prefix[data_key] = osp.join(
self.data_root, prefix)
else:
self.data_prefix[data_key] = prefix
else:
raise TypeError('prefix should be a string, but got '
f'{type(prefix)}')

def load_data_list(self) -> List[dict]:
"""Load data list."""
data_list = []

for sub_ann_file in self.ann_file:

annotations = mmengine.load(sub_ann_file)

for ann in annotations:

                # Example annotation:
                # {
                #     'imgname': '41699051005347.png',
                #     'query': 'How many food item i...bar graph?',
                #     'label': '14'
                # }

data_info = dict(question=ann['query'])
data_info['image_id'] = ann['imgname']

img_path = mmengine.join_path(self.data_prefix['img_path'],
ann['imgname'])

data_info['img_path'] = img_path
data_info['gt_answer'] = ann['label']

if 'human' in sub_ann_file:
data_info['sub_set'] = 'ChartQA-H'
elif 'augmented' in sub_ann_file:
data_info['sub_set'] = 'ChartQA-M'
else:
                    raise ValueError(
                        f'Cannot infer the sub-set of annotation file '
                        f'{sub_ann_file}; its name should contain '
                        '"human" or "augmented".')

data_list.append(data_info)

return data_list
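For reference, a minimal usage sketch of the new dataset class (not part of this commit; the paths are illustrative and assume the data/chartqa layout shown in the class docstring, and the import requires the multimodal extras since the export sits behind WITH_MULTIMODAL):

# Hedged usage sketch; the paths below are assumptions, not repo fixtures.
from mmpretrain.datasets import ChartQA

dataset = ChartQA(
    data_root='data/chartqa',
    data_prefix='test/png',
    ann_file=['test/test_human.json', 'test/test_augmented.json'],
)

# Each raw item holds the question, image path, ground-truth answer and a
# sub-set tag: 'ChartQA-H' (human) or 'ChartQA-M' (machine-augmented).
print(dataset.get_data_info(0))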
3 changes: 2 additions & 1 deletion mmpretrain/evaluation/metrics/__init__.py
@@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .caption import COCOCaption
from .chartqa import ChartQARelaxACC
from .gqa import GQAAcc
from .multi_label import AveragePrecision, MultiLabelMetric
from .multi_task import MultiTasksMetric
@@ -16,5 +17,5 @@
'MultiTasksMetric', 'VOCAveragePrecision', 'VOCMultiLabelMetric',
'ConfusionMatrix', 'RetrievalRecall', 'VQAAcc', 'ReportVQA', 'COCOCaption',
'VisualGroundingMetric', 'ScienceQAMetric', 'GQAAcc', 'NocapsSave',
'RetrievalAveragePrecision'
'RetrievalAveragePrecision', 'ChartQARelaxACC'
]
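With the metric registered and exported here, it can be referenced by name from a config. A hedged sketch (the evaluator keys follow the usual MMEngine config convention; the threshold is just the default made explicit):

val_evaluator = dict(type='ChartQARelaxACC', relax_thresh=0.05)
test_evaluator = val_evaluator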
130 changes: 130 additions & 0 deletions mmpretrain/evaluation/metrics/chartqa.py
@@ -0,0 +1,130 @@
# Copyright (c) OpenMMLab. All rights reserved.
import re
from typing import List, Optional

from mmengine.evaluator import BaseMetric

from mmpretrain.registry import METRICS
from .vqa import _process_digit_article, _process_punctuation


@METRICS.register_module()
class ChartQARelaxACC(BaseMetric):
    '''ChartQA relaxed-accuracy metric.

    A numeric prediction is counted as correct if it falls within a relative
    tolerance (``relax_thresh``) of the ground-truth value; any other answer
    must match the ground truth exactly after normalization.

    Args:
        full_score_weight (float): Kept for interface consistency with the
            other VQA metrics; it is not used by the relaxed-accuracy
            computation. Defaults to 0.3.
        collect_device (str): Device name used for collecting results from
            different ranks during distributed training. Must be 'cpu' or
            'gpu'. Defaults to 'cpu'.
        prefix (str, optional): The prefix that will be added in the metric
            names to disambiguate homonymous metrics of different evaluators.
            If prefix is not provided in the argument, self.default_prefix
            will be used instead. Defaults to None.
        relax_thresh (float): Relative tolerance applied to numeric answers.
            Defaults to 0.05.
    '''
default_prefix = 'ChartQARelaxACC'

def __init__(self,
full_score_weight: float = 0.3,
collect_device: str = 'cpu',
prefix: Optional[str] = None,
relax_thresh: float = 0.05):
super().__init__(collect_device=collect_device, prefix=prefix)
self.full_score_weight = full_score_weight
self.relax_thresh = relax_thresh

    def is_digit(self, x: str):
        """Check whether ``x`` is a signed integer or decimal string."""
        # Matches signed integers and decimals, e.g. '-5', '14' or '3.2';
        # `isnumeric` additionally covers unsigned digit strings.
        is_number = bool(re.match(r'^[+-]?\d+(\.\d+)?$', x))
        return is_number or str(x).isnumeric()

def process(self, data_batch, data_samples):
"""Process one batch of data samples.
The processed results should be stored in ``self.results``, which will
be used to computed the metrics when all batches have been processed.
Args:
data_batch: A batch of data from the dataloader.
data_samples (Sequence[dict]): A batch of outputs from the model.
"""
for sample in data_samples:
gt_answer = sample.get('gt_answer')
sub_set = sample.get('sub_set')

is_digit = self.is_digit(gt_answer)

result = {
'pred_answer': sample.get('pred_answer'),
'gt_answer': gt_answer,
'is_digit': is_digit,
'sub_set': sub_set
}

self.results.append(result)

def compute_metrics(self, results: List):
"""Compute the metrics from processed results.
Args:
results (dict): The processed results of each batch.
Returns:
Dict: The computed metrics. The keys are the names of the metrics,
and the values are corresponding results.
"""
ChartQA_H_acc = []
ChartQA_M_acc = []
for result in results:
pred_answer = self._process_answer(result['pred_answer'])
gt_answer = result['gt_answer']
is_digit = result['is_digit']
sub_set = result['sub_set']

if is_digit:
if self.is_digit(pred_answer):
pred_answer = float(pred_answer)
gt_answer = float(gt_answer)
upper_bound = \
max(gt_answer - gt_answer * self.relax_thresh,
gt_answer + gt_answer * self.relax_thresh)
lower_bound = \
min(gt_answer - gt_answer * self.relax_thresh,
gt_answer + gt_answer * self.relax_thresh)
chart_acc = float(
all([
pred_answer >= lower_bound,
pred_answer <= upper_bound
]))
else:
chart_acc = 0.0
else:
chart_acc = float(pred_answer == gt_answer)

if sub_set == 'ChartQA-H':
ChartQA_H_acc.append(chart_acc)
elif sub_set == 'ChartQA-M':
ChartQA_M_acc.append(chart_acc)
else:
                raise ValueError(f'Unknown sub-set {sub_set}.')

ChartQA_H_acc = sum(ChartQA_H_acc) / len(ChartQA_H_acc) * 100
ChartQA_M_acc = sum(ChartQA_M_acc) / len(ChartQA_M_acc) * 100

accuracy = (ChartQA_H_acc + ChartQA_M_acc) / 2

metrics = {
'ChartQA-H acc': ChartQA_H_acc,
'ChartQA-M acc': ChartQA_M_acc,
'Overall acc': accuracy
}

return metrics

    def _process_answer(self, answer):
        """Normalize a predicted answer: unify whitespace, then apply the
        shared VQA punctuation/article/digit normalization."""
answer = answer.replace('\n', ' ')
answer = answer.replace('\t', ' ')
answer = answer.strip()
answer = _process_punctuation(answer)
answer = _process_digit_article(answer)
return answer
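A minimal sketch of the relaxed-accuracy rule in action, using fabricated samples (the values are illustrative only, not from the ChartQA benchmark):

# With relax_thresh=0.05, the ground truth '14' accepts any numeric
# prediction within [13.3, 14.7].
from mmpretrain.evaluation.metrics import ChartQARelaxACC

metric = ChartQARelaxACC(relax_thresh=0.05)
metric.process(
    data_batch=None,  # unused by this metric
    data_samples=[
        # 13.5 lies inside the 5% band around 14 -> correct
        dict(pred_answer='13.5', gt_answer='14', sub_set='ChartQA-H'),
        # 15.0 falls outside [13.3, 14.7] -> wrong
        dict(pred_answer='15.0', gt_answer='14', sub_set='ChartQA-H'),
        # non-numeric answers need an exact match after normalization
        dict(pred_answer='Yes', gt_answer='yes', sub_set='ChartQA-M'),
    ])
print(metric.compute_metrics(metric.results))
# {'ChartQA-H acc': 50.0, 'ChartQA-M acc': 100.0, 'Overall acc': 75.0}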
