[Feature] Support ubody dataset (2d keypoints) (#2588)
xiexinch authored Aug 7, 2023
1 parent a2769aa commit cb48094
Showing 14 changed files with 1,797 additions and 9 deletions.
1,153 changes: 1,153 additions & 0 deletions configs/_base_/datasets/ubody2d.py

Large diffs are not rendered by default.

9 changes: 9 additions & 0 deletions configs/wholebody_2d_keypoint/topdown_heatmap/README.md
@@ -18,9 +18,18 @@ Results on COCO-WholeBody v1.0 val with detector having human AP of 56.4 on COCO
| HRNet-w32+Dark | 256x192 | 0.582 | 0.671 | [hrnet_dark_coco-wholebody.md](./coco-wholebody/hrnet_dark_coco-wholebody.md) |
| HRNet-w48 | 256x192 | 0.579 | 0.681 | [hrnet_coco-wholebody.md](./coco-wholebody/hrnet_coco-wholebody.md) |
| CSPNeXt-m | 256x192 | 0.567 | 0.641 | [cspnext_udp_coco-wholebody.md](./coco-wholebody/cspnext_udp_coco-wholebody.md) |
| HRNet-w32 | 256x192 | 0.549 | 0.646 | [hrnet_ubody-coco-wholebody.md](./ubody2d/hrnet_ubody-coco-wholebody.md) |
| ResNet-152 | 256x192 | 0.548 | 0.661 | [resnet_coco-wholebody.md](./coco-wholebody/resnet_coco-wholebody.md) |
| HRNet-w32 | 256x192 | 0.536 | 0.636 | [hrnet_coco-wholebody.md](./coco-wholebody/hrnet_coco-wholebody.md) |
| ResNet-101 | 256x192 | 0.531 | 0.645 | [resnet_coco-wholebody.md](./coco-wholebody/resnet_coco-wholebody.md) |
| S-ViPNAS-Res50+Dark | 256x192 | 0.528 | 0.632 | [vipnas_dark_coco-wholebody.md](./coco-wholebody/vipnas_dark_coco-wholebody.md) |
| ResNet-50 | 256x192 | 0.521 | 0.633 | [resnet_coco-wholebody.md](./coco-wholebody/resnet_coco-wholebody.md) |
| S-ViPNAS-Res50 | 256x192 | 0.495 | 0.607 | [vipnas_coco-wholebody.md](./coco-wholebody/vipnas_coco-wholebody.md) |

### UBody2D Dataset

Results on the UBody val set, computed with ground-truth keypoints.

| Model | Input Size | Whole AP | Whole AR | Details and Download |
| :-------: | :--------: | :------: | :------: | :----------------------------------------------------------------------: |
| HRNet-w32 | 256x192 | 0.690 | 0.729 | [hrnet_ubody-coco-wholebody.md](./ubody2d/hrnet_ubody-coco-wholebody.md) |
@@ -0,0 +1,23 @@
Models:
- Config: configs/wholebody_2d_keypoint/topdown_heatmap/ubody2d/td-hm_hrnet-w32_8xb64-210e_ubody-256x192.py
  In Collection: HRNet
  Metadata:
    Architecture: &id001
    - HRNet
    Training Data: UBody-COCO-WholeBody
  Name: td-hm_hrnet-w32_8xb64-210e_ubody-256x192
  Results:
  - Dataset: COCO-WholeBody
    Metrics:
      Body AP: 0.678
      Body AR: 0.755
      Face AP: 0.630
      Face AR: 0.708
      Foot AP: 0.543
      Foot AR: 0.661
      Hand AP: 0.467
      Hand AR: 0.566
      Whole AP: 0.536
      Whole AR: 0.636
    Task: Wholebody 2D Keypoint
  Weights: https://download.openmmlab.com/mmpose/v1/wholebody_2d_keypoint/ubody/td-hm_hrnet-w32_8xb64-210e_ubody-coco-256x192-7c227391_20230807.pth
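
The metafile's `Config` and `Weights` fields are enough to try the model directly. Below is a minimal sketch, assuming MMPose 1.x is installed; `demo.jpg` is a hypothetical local test image, and the config path and weights URL are the ones listed in the metafile above.

```python
# Minimal sketch (not part of this commit): run the released checkpoint on one
# image with the high-level MMPose 1.x APIs. `demo.jpg` is a hypothetical file.
from mmpose.apis import inference_topdown, init_model

config = ('configs/wholebody_2d_keypoint/topdown_heatmap/ubody2d/'
          'td-hm_hrnet-w32_8xb64-210e_ubody-256x192.py')
checkpoint = ('https://download.openmmlab.com/mmpose/v1/wholebody_2d_keypoint/'
              'ubody/td-hm_hrnet-w32_8xb64-210e_ubody-coco-256x192-'
              '7c227391_20230807.pth')

model = init_model(config, checkpoint, device='cpu')
# With no bounding boxes given, the whole image is used as a single person box.
results = inference_topdown(model, 'demo.jpg')
print(results[0].pred_instances.keypoints.shape)  # (1, 133, 2) whole-body keypoints
```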
@@ -0,0 +1,38 @@
<!-- [ALGORITHM] -->

<details>
<summary align="right"><a href="http://openaccess.thecvf.com/content_CVPR_2019/html/Sun_Deep_High-Resolution_Representation_Learning_for_Human_Pose_Estimation_CVPR_2019_paper.html">HRNet (CVPR'2019)</a></summary>

```bibtex
@inproceedings{sun2019deep,
title={Deep high-resolution representation learning for human pose estimation},
author={Sun, Ke and Xiao, Bin and Liu, Dong and Wang, Jingdong},
booktitle={Proceedings of the IEEE conference on computer vision and pattern recognition},
pages={5693--5703},
year={2019}
}
```

</details>

<!-- [DATASET] -->

<details>
<summary align="right"><a href="https://arxiv.org/abs/2303.16160">UBody (CVPR'2023)</a></summary>

```bibtex
@inproceedings{lin2023one,
title={One-Stage 3D Whole-Body Mesh Recovery with Component Aware Transformer},
author={Lin, Jing and Zeng, Ailing and Wang, Haoqian and Zhang, Lei and Li, Yu},
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
year={2023},
}
```

</details>

Results on COCO-WholeBody v1.0 val with detector having human AP of 56.4 on COCO val2017 dataset

| Arch | Input Size | Body AP | Body AR | Foot AP | Foot AR | Face AP | Face AR | Hand AP | Hand AR | Whole AP | Whole AR | ckpt | log |
| :-------------------------------------- | :--------: | :-----: | :-----: | :-----: | :-----: | :-----: | :-----: | :-----: | :-----: | :------: | :------: | :--------------------------------------: | :-------------------------------------: |
| [pose_hrnet_w32](/configs/wholebody_2d_keypoint/topdown_heatmap/ubody2d/td-hm_hrnet-w32_8xb64-210e_ubody-256x192.py) | 256x192 | 0.685 | 0.759 | 0.564 | 0.675 | 0.625 | 0.705 | 0.516 | 0.609 | 0.549 | 0.646 | [ckpt](https://download.openmmlab.com/mmpose/v1/wholebody_2d_keypoint/ubody/td-hm_hrnet-w32_8xb64-210e_ubody-coco-256x192-7c227391_20230807.pth) | [log](https://download.openmmlab.com/mmpose/v1/wholebody_2d_keypoint/ubody/td-hm_hrnet-w32_8xb64-210e_ubody-coco-256x192-7c227391_20230807.json) |
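
For a quick sanity check of the released checkpoint without installing MMPose, the `ckpt` file above can be inspected with plain PyTorch. A sketch — the key names shown are typical for MMPose 1.x checkpoints, not guaranteed:

```python
# Sketch: download and inspect the checkpoint from the table above.
from torch.hub import load_state_dict_from_url

url = ('https://download.openmmlab.com/mmpose/v1/wholebody_2d_keypoint/ubody/'
       'td-hm_hrnet-w32_8xb64-210e_ubody-coco-256x192-7c227391_20230807.pth')
ckpt = load_state_dict_from_url(url, map_location='cpu')
print(sorted(ckpt.keys()))  # typically includes 'meta' and 'state_dict'
# The heatmap head's final 1x1 conv emits 133 channels, one per keypoint.
w = ckpt['state_dict']['head.final_layer.weight']
print(w.shape)  # expected: torch.Size([133, 32, 1, 1])
```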
@@ -0,0 +1,173 @@
_base_ = ['../../../_base_/default_runtime.py']

# runtime
train_cfg = dict(max_epochs=210, val_interval=10)

# optimizer
optim_wrapper = dict(optimizer=dict(
    type='Adam',
    lr=5e-4,
))

# learning policy
param_scheduler = [
    dict(
        type='LinearLR', begin=0, end=500, start_factor=0.001,
        by_epoch=False),  # warm-up
    dict(
        type='MultiStepLR',
        begin=0,
        end=210,
        milestones=[170, 200],
        gamma=0.1,
        by_epoch=True)
]

# automatically scaling LR based on the actual training batch size
auto_scale_lr = dict(base_batch_size=512)

# hooks
default_hooks = dict(
    checkpoint=dict(save_best='coco-wholebody/AP', rule='greater'))

# codec settings
codec = dict(
    type='MSRAHeatmap', input_size=(192, 256), heatmap_size=(48, 64), sigma=2)

# model settings
model = dict(
    type='TopdownPoseEstimator',
    data_preprocessor=dict(
        type='PoseDataPreprocessor',
        mean=[123.675, 116.28, 103.53],
        std=[58.395, 57.12, 57.375],
        bgr_to_rgb=True),
    backbone=dict(
        type='HRNet',
        in_channels=3,
        extra=dict(
            stage1=dict(
                num_modules=1,
                num_branches=1,
                block='BOTTLENECK',
                num_blocks=(4, ),
                num_channels=(64, )),
            stage2=dict(
                num_modules=1,
                num_branches=2,
                block='BASIC',
                num_blocks=(4, 4),
                num_channels=(32, 64)),
            stage3=dict(
                num_modules=4,
                num_branches=3,
                block='BASIC',
                num_blocks=(4, 4, 4),
                num_channels=(32, 64, 128)),
            stage4=dict(
                num_modules=3,
                num_branches=4,
                block='BASIC',
                num_blocks=(4, 4, 4, 4),
                num_channels=(32, 64, 128, 256))),
        init_cfg=dict(
            type='Pretrained',
            checkpoint='https://download.openmmlab.com/mmpose/'
            'pretrain_models/hrnet_w32-36af842e.pth'),
    ),
    head=dict(
        type='HeatmapHead',
        in_channels=32,
        out_channels=133,
        deconv_out_channels=None,
        loss=dict(type='KeypointMSELoss', use_target_weight=True),
        decoder=codec),
    test_cfg=dict(
        flip_test=True,
        flip_mode='heatmap',
        shift_heatmap=True,
    ))

# base dataset settings
dataset_type = 'UBody2dDataset'
data_mode = 'topdown'
data_root = 'data/UBody/'

scenes = [
    'Magic_show', 'Entertainment', 'ConductMusic', 'Online_class', 'TalkShow',
    'Speech', 'Fitness', 'Interview', 'Olympic', 'TVShow', 'Singing',
    'SignLanguage', 'Movie', 'LiveVlog', 'VideoConference'
]

train_datasets = [
    dict(
        type='CocoWholeBodyDataset',
        data_root='data/coco/',
        data_mode=data_mode,
        ann_file='annotations/coco_wholebody_train_v1.0.json',
        data_prefix=dict(img='train2017/'),
        pipeline=[])
]

for scene in scenes:
    train_dataset = dict(
        type=dataset_type,
        data_root=data_root,
        data_mode=data_mode,
        ann_file=f'annotations/{scene}/train_annotations.json',
        data_prefix=dict(img='images/'),
        pipeline=[],
        sample_interval=10)
    train_datasets.append(train_dataset)

# pipelines
train_pipeline = [
    dict(type='LoadImage'),
    dict(type='GetBBoxCenterScale'),
    dict(type='RandomFlip', direction='horizontal'),
    dict(type='RandomHalfBody'),
    dict(type='RandomBBoxTransform'),
    dict(type='TopdownAffine', input_size=codec['input_size']),
    dict(type='GenerateTarget', encoder=codec),
    dict(type='PackPoseInputs')
]
val_pipeline = [
    dict(type='LoadImage'),
    dict(type='GetBBoxCenterScale'),
    dict(type='TopdownAffine', input_size=codec['input_size']),
    dict(type='PackPoseInputs')
]

# data loaders
train_dataloader = dict(
    batch_size=64,
    num_workers=2,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=True),
    dataset=dict(
        type='CombinedDataset',
        metainfo=dict(from_file='configs/_base_/datasets/coco_wholebody.py'),
        datasets=train_datasets,
        pipeline=train_pipeline,
        test_mode=False,
    ))
val_dataloader = dict(
    batch_size=32,
    num_workers=2,
    persistent_workers=True,
    drop_last=False,
    sampler=dict(type='DefaultSampler', shuffle=False, round_up=False),
    dataset=dict(
        type='CocoWholeBodyDataset',
        ann_file='data/coco/annotations/coco_wholebody_val_v1.0.json',
        data_prefix=dict(img='data/coco/val2017/'),
        pipeline=val_pipeline,
        bbox_file='data/coco/person_detection_results/'
        'COCO_val2017_detections_AP_H_56_person.json',
        test_mode=True))
test_dataloader = val_dataloader

val_evaluator = dict(
    type='CocoWholeBodyMetric',
    ann_file='data/coco/annotations/coco_wholebody_val_v1.0.json')
test_evaluator = val_evaluator
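
The config trains on COCO-WholeBody plus all 15 UBody scenes, merged by `CombinedDataset` under the COCO-WholeBody metainfo, and validates on COCO-WholeBody val. To launch it from Python rather than `tools/train.py`, a minimal sketch — assuming MMEngine/MMPose 1.x, both datasets prepared under `data/`, and a hypothetical work directory:

```python
# Sketch: launch training from Python, equivalent to
#   python tools/train.py <config>
from mmengine.config import Config
from mmengine.runner import Runner

cfg = Config.fromfile(
    'configs/wholebody_2d_keypoint/topdown_heatmap/ubody2d/'
    'td-hm_hrnet-w32_8xb64-210e_ubody-256x192.py')
cfg.work_dir = 'work_dirs/ubody2d_hrnet_w32'  # hypothetical output directory

runner = Runner.from_cfg(cfg)
runner.train()  # 210 epochs, validating every 10 on COCO-WholeBody val
```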
82 changes: 82 additions & 0 deletions docs/en/dataset_zoo/2d_wholebody_keypoint.md
Original file line number Diff line number Diff line change
Expand Up @@ -131,3 +131,85 @@ mmpose
Please also install the latest version of [Extended COCO API](https://github.com/jin-s13/xtcocoapi) (version>=1.5) to support Halpe evaluation:

`pip install xtcocotools`

## UBody

<!-- [DATASET] -->

<details>
<summary align="right"><a href="https://arxiv.org/abs/2303.16160">UBody (CVPR'2023)</a></summary>

```bibtex
@inproceedings{lin2023one,
title={One-Stage 3D Whole-Body Mesh Recovery with Component Aware Transformer},
author={Lin, Jing and Zeng, Ailing and Wang, Haoqian and Zhang, Lei and Li, Yu},
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
year={2023},
}
```

</details>

<div align="center">
<img src="https://github.com/open-mmlab/mmpose/assets/15952744/0c97e43a-46a9-46a3-a5dd-b84bf9d6d6f2" height="300px">
</div>

For the [UBody](https://github.com/IDEA-Research/OSX) dataset, videos and annotations can be downloaded from the [OSX homepage](https://github.com/IDEA-Research/OSX).

Download and extract them under $MMPOSE/data, and make them look like this:

```text
mmpose
├── mmpose
├── docs
├── tests
├── tools
├── configs
`── data
    │── UBody
        ├── annotations
        │   ├── ConductMusic
        │   ├── Entertainment
        │   ├── Fitness
        │   ├── Interview
        │   ├── LiveVlog
        │   ├── Magic_show
        │   ├── Movie
        │   ├── Olympic
        │   ├── Online_class
        │   ├── SignLanguage
        │   ├── Singing
        │   ├── Speech
        │   ├── TVShow
        │   ├── TalkShow
        │   └── VideoConference
        ├── splits
        │   ├── inter_scene_test_list.npy
        │   └── intra_scene_test_list.npy
        ├── videos
        │   ├── ConductMusic
        │   ├── Entertainment
        │   ├── Fitness
        │   ├── Interview
        │   ├── LiveVlog
        │   ├── Magic_show
        │   ├── Movie
        │   ├── Olympic
        │   ├── Online_class
        │   ├── SignLanguage
        │   ├── Singing
        │   ├── Speech
        │   ├── TVShow
        │   ├── TalkShow
        │   └── VideoConference
```

Convert the videos to images, then split them into train/val sets:

```shell
python tools/dataset_converters/ubody_kpts_to_coco.py
```
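
The converter walks each scene's videos, dumps frames under `images/`, and writes COCO-style annotations. As a rough illustration of the frame-dumping step only — not the actual script's logic — here is a simplified OpenCV sketch with a hypothetical video path:

```python
# Simplified, illustrative sketch of the video-to-frames step; the real
# conversion (including annotation export) lives in
# tools/dataset_converters/ubody_kpts_to_coco.py.
import os

import cv2


def extract_frames(video_path: str, out_dir: str) -> None:
    """Dump every frame of ``video_path`` as a JPEG into ``out_dir``."""
    os.makedirs(out_dir, exist_ok=True)
    cap = cv2.VideoCapture(video_path)
    idx = 0
    while True:
        ok, frame = cap.read()
        if not ok:
            break
        cv2.imwrite(os.path.join(out_dir, f'{idx:06d}.jpg'), frame)
        idx += 1
    cap.release()


extract_frames('data/UBody/videos/TalkShow/clip_000.mp4',  # hypothetical file
               'data/UBody/images/TalkShow/clip_000')
```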

Please also install the latest version of [Extended COCO API](https://github.com/jin-s13/xtcocoapi) (version>=1.5) to support COCO-WholeBody evaluation:

`pip install xtcocotools`