Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Support ubody dataset (2d keypoints) #2588

Merged
merged 11 commits into from
Aug 7, 2023
1,153 changes: 1,153 additions & 0 deletions configs/_base_/datasets/ubody2d.py

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
_base_ = ['../../../_base_/default_runtime.py']

# runtime
train_cfg = dict(max_epochs=210, val_interval=10)

# optimizer
optim_wrapper = dict(optimizer=dict(
type='Adam',
lr=5e-4,
))

# learning policy
param_scheduler = [
dict(
type='LinearLR', begin=0, end=500, start_factor=0.001,
by_epoch=False), # warm-up
dict(
type='MultiStepLR',
begin=0,
end=210,
milestones=[170, 200],
gamma=0.1,
by_epoch=True)
]

# automatically scaling LR based on the actual training batch size
auto_scale_lr = dict(base_batch_size=512)

# hooks
default_hooks = dict(
checkpoint=dict(save_best='coco-wholebody/AP', rule='greater'))

# codec settings
codec = dict(
type='MSRAHeatmap', input_size=(192, 256), heatmap_size=(48, 64), sigma=2)

# model settings
model = dict(
type='TopdownPoseEstimator',
data_preprocessor=dict(
type='PoseDataPreprocessor',
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
bgr_to_rgb=True),
backbone=dict(
type='HRNet',
in_channels=3,
extra=dict(
stage1=dict(
num_modules=1,
num_branches=1,
block='BOTTLENECK',
num_blocks=(4, ),
num_channels=(64, )),
stage2=dict(
num_modules=1,
num_branches=2,
block='BASIC',
num_blocks=(4, 4),
num_channels=(32, 64)),
stage3=dict(
num_modules=4,
num_branches=3,
block='BASIC',
num_blocks=(4, 4, 4),
num_channels=(32, 64, 128)),
stage4=dict(
num_modules=3,
num_branches=4,
block='BASIC',
num_blocks=(4, 4, 4, 4),
num_channels=(32, 64, 128, 256))),
init_cfg=dict(
type='Pretrained',
checkpoint='https://download.openmmlab.com/mmpose/'
'pretrain_models/hrnet_w32-36af842e.pth'),
),
head=dict(
type='HeatmapHead',
in_channels=32,
out_channels=133,
deconv_out_channels=None,
loss=dict(type='KeypointMSELoss', use_target_weight=True),
decoder=codec),
test_cfg=dict(
flip_test=True,
flip_mode='heatmap',
shift_heatmap=True,
))

# base dataset settings
dataset_type = 'UBody2dDataset'
data_mode = 'topdown'
data_root = 'data/UBody/'

scenes = [
'Magic_show', 'Entertainment', 'ConductMusic', 'Online_class', 'TalkShow',
'Speech', 'Fitness', 'Interview', 'Olympic', 'TVShow', 'Singing',
'SignLanguage', 'Movie', 'LiveVlog', 'VideoConference'
]

train_datasets = []

for scene in scenes:
train_dataset = dict(
type=dataset_type,
data_root=data_root,
data_mode=data_mode,
ann_file=f'annotations/{scene}/train_annotations.json',
data_prefix=dict(img='images/'),
pipeline=[],
sample_interval=10)
train_datasets.append(train_dataset)

# pipelines
train_pipeline = [
dict(type='LoadImage'),
dict(type='GetBBoxCenterScale'),
dict(type='RandomFlip', direction='horizontal'),
dict(type='RandomHalfBody'),
dict(type='RandomBBoxTransform'),
dict(type='TopdownAffine', input_size=codec['input_size']),
dict(type='GenerateTarget', encoder=codec),
dict(type='PackPoseInputs')
]
val_pipeline = [
dict(type='LoadImage'),
dict(type='GetBBoxCenterScale'),
dict(type='TopdownAffine', input_size=codec['input_size']),
dict(type='PackPoseInputs')
]

# data loaders
train_dataloader = dict(
batch_size=64,
num_workers=2,
persistent_workers=True,
sampler=dict(type='DefaultSampler', shuffle=True),
dataset=dict(
type='CombinedDataset',
metainfo=dict(from_file='configs/_base_/datasets/coco_wholebody.py'),
datasets=train_datasets,
pipeline=train_pipeline,
test_mode=False,
))
val_dataloader = dict(
batch_size=32,
num_workers=2,
persistent_workers=True,
drop_last=False,
sampler=dict(type='DefaultSampler', shuffle=False, round_up=False),
dataset=dict(
type='UBody2dDataset',
ann_file=data_root + 'annotations/val_annotations.json',
data_prefix=dict(img=data_root + 'images/'),
pipeline=val_pipeline,
test_mode=True))
test_dataloader = val_dataloader

val_evaluator = dict(
type='CocoWholeBodyMetric',
use_area=False,
ann_file=data_root + 'annotations/val_annotations.json')
test_evaluator = val_evaluator
82 changes: 82 additions & 0 deletions docs/en/dataset_zoo/2d_wholebody_keypoint.md
Original file line number Diff line number Diff line change
Expand Up @@ -131,3 +131,85 @@ mmpose
Please also install the latest version of [Extended COCO API](https://github.com/jin-s13/xtcocoapi) (version>=1.5) to support Halpe evaluation:

`pip install xtcocotools`

## UBody

<!-- [DATASET] -->

<details>
<summary align="right"><a href="https://arxiv.org/abs/2303.16160">UBody (CVPR'2023)</a></summary>

```bibtex
@article{lin2023one,
title={One-Stage 3D Whole-Body Mesh Recovery with Component Aware Transformer},
author={Lin, Jing and Zeng, Ailing and Wang, Haoqian and Zhang, Lei and Li, Yu},
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
year={2023},
}
```

</details>

<div align="center">
<img src="https://github.com/IDEA-Research/OSX/blob/main/assets/demo_video.gif" height="300px">
</div>

For [Ubody](https://github.com/IDEA-Research/OSX) dataset, videos and annotations can be downloaded from [OSX homepage](https://github.com/IDEA-Research/OSX).

Download and extract them under $MMPOSE/data, and make them look like this:

```text
mmpose
├── mmpose
├── docs
├── tests
├── tools
├── configs
`── data
│── UBody
├── annotations
│   ├── ConductMusic
│   ├── Entertainment
│   ├── Fitness
│   ├── Interview
│   ├── LiveVlog
│   ├── Magic_show
│   ├── Movie
│   ├── Olympic
│   ├── Online_class
│   ├── SignLanguage
│   ├── Singing
│   ├── Speech
│   ├── TVShow
│   ├── TalkShow
│   └── VideoConference
├── splits
│   ├── inter_scene_test_list.npy
│   └── intra_scene_test_list.npy
├── videos
│   ├── ConductMusic
│   ├── Entertainment
│   ├── Fitness
│   ├── Interview
│   ├── LiveVlog
│   ├── Magic_show
│   ├── Movie
│   ├── Olympic
│   ├── Online_class
│   ├── SignLanguage
│   ├── Singing
│   ├── Speech
│   ├── TVShow
│   ├── TalkShow
│   └── VideoConference
```

Convert videos to images then split them into train/val set:

```shell
python tools/dataset_converters/ubody_kpts_to_coco.py
```

Please also install the latest version of [Extended COCO API](https://github.com/jin-s13/xtcocoapi) (version>=1.5) to support COCO-WholeBody evaluation:

`pip install xtcocotools`
8 changes: 7 additions & 1 deletion mmpose/datasets/datasets/base/base_coco_style_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ class BaseCocoStyleDataset(BaseDataset):
max_refetch (int, optional): If ``Basedataset.prepare_data`` get a
None img. The maximum extra number of cycles to get a valid
image. Default: 1000.
sample_interval (int, optional): The sample interval of the dataset.
Default: 1.
"""

METAINFO: dict = dict()
Expand All @@ -73,7 +75,8 @@ def __init__(self,
pipeline: List[Union[dict, Callable]] = [],
test_mode: bool = False,
lazy_init: bool = False,
max_refetch: int = 1000):
max_refetch: int = 1000,
sample_interval: int = 1):

if data_mode not in {'topdown', 'bottomup'}:
raise ValueError(
Expand All @@ -94,6 +97,7 @@ def __init__(self,
'while "bbox_file" is only '
'supported when `test_mode==True`.')
self.bbox_file = bbox_file
self.sample_interval = sample_interval

super().__init__(
ann_file=ann_file,
Expand Down Expand Up @@ -207,6 +211,8 @@ def _load_annotations(self) -> Tuple[List[dict], List[dict]]:
image_list = []

for img_id in self.coco.getImgIds():
if img_id % self.sample_interval != 0:
continue
img = self.coco.loadImgs(img_id)[0]
img.update({
'img_id':
Expand Down
3 changes: 2 additions & 1 deletion mmpose/datasets/datasets/wholebody/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .coco_wholebody_dataset import CocoWholeBodyDataset
from .halpe_dataset import HalpeDataset
from .ubody2d_dataset import UBody2dDataset

__all__ = ['CocoWholeBodyDataset', 'HalpeDataset']
__all__ = ['CocoWholeBodyDataset', 'HalpeDataset', 'UBody2dDataset']
61 changes: 61 additions & 0 deletions mmpose/datasets/datasets/wholebody/ubody2d_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmpose.registry import DATASETS
from .coco_wholebody_dataset import CocoWholeBodyDataset


@DATASETS.register_module()
class UBody2dDataset(CocoWholeBodyDataset):
"""Ubody2d dataset for pose estimation.

"One-Stage 3D Whole-Body Mesh Recovery with Component Aware Transformer",
CVPR'2023. More details can be found in the `paper
<https://arxiv.org/abs/2303.16160>`__ .

Ubody2D keypoints::

0-16: 17 body keypoints,
17-22: 6 foot keypoints,
23-90: 68 face keypoints,
91-132: 42 hand keypoints

In total, we have 133 keypoints for wholebody pose estimation.

Args:
ann_file (str): Annotation file path. Default: ''.
bbox_file (str, optional): Detection result file path. If
``bbox_file`` is set, detected bboxes loaded from this file will
be used instead of ground-truth bboxes. This setting is only for
evaluation, i.e., ignored when ``test_mode`` is ``False``.
Default: ``None``.
data_mode (str): Specifies the mode of data samples: ``'topdown'`` or
``'bottomup'``. In ``'topdown'`` mode, each data sample contains
one instance; while in ``'bottomup'`` mode, each data sample
contains all instances in a image. Default: ``'topdown'``
metainfo (dict, optional): Meta information for dataset, such as class
information. Default: ``None``.
data_root (str, optional): The root directory for ``data_prefix`` and
``ann_file``. Default: ``None``.
data_prefix (dict, optional): Prefix for training data. Default:
``dict(img=None, ann=None)``.
filter_cfg (dict, optional): Config for filter data. Default: `None`.
indices (int or Sequence[int], optional): Support using first few
data in annotation file to facilitate training/testing on a smaller
dataset. Default: ``None`` which means using all ``data_infos``.
serialize_data (bool, optional): Whether to hold memory using
serialized objects, when enabled, data loader workers can use
shared RAM from master process instead of making a copy.
Default: ``True``.
pipeline (list, optional): Processing pipeline. Default: [].
test_mode (bool, optional): ``test_mode=True`` means in test phase.
Default: ``False``.
lazy_init (bool, optional): Whether to load annotation during
instantiation. In some cases, such as visualization, only the meta
information of the dataset is needed, which is not necessary to
load annotation file. ``Basedataset`` can skip load annotations to
save time by set ``lazy_init=False``. Default: ``False``.
max_refetch (int, optional): If ``Basedataset.prepare_data`` get a
None img. The maximum extra number of cycles to get a valid
image. Default: 1000.
"""
Tau-J marked this conversation as resolved.
Show resolved Hide resolved

METAINFO: dict = dict(from_file='configs/_base_/datasets/ubody2d.py')
Loading