[Feature] Support LaPa Dataset (open-mmlab#2281)

Tau-J authored Apr 23, 2023
1 parent 0e7737d commit 7a327b3
Showing 19 changed files with 1,543 additions and 70 deletions.
688 changes: 688 additions & 0 deletions configs/_base_/datasets/lapa.py

Large diffs are not rendered by default.

8 changes: 8 additions & 0 deletions configs/face_2d_keypoint/rtmpose/README.md
@@ -22,3 +22,11 @@ Results on WFLW dataset
| Model | Input Size | NME | Details and Download |
| :-------: | :--------: | :--: | :---------------------------------------: |
| RTMPose-m | 256x256 | 4.01 | [rtmpose_wflw.md](./wflw/rtmpose_wflw.md) |

### LaPa Dataset

Results on LaPa dataset

| Model | Input Size | NME | Details and Download |
| :-------: | :--------: | :--: | :---------------------------------------: |
| RTMPose-m | 256x256 | 1.29 | [rtmpose_lapa.md](./lapa/rtmpose_lapa.md) |
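
For context: NME (Normalized Mean Error) averages the Euclidean landmark error and normalizes it by a reference distance (for face datasets, commonly an inter-keypoint distance such as the inter-ocular distance); lower is better. A standard form, with $N$ images, $K$ landmarks and per-image normalizer $d_i$, is:

```math
\mathrm{NME} = \frac{1}{N}\sum_{i=1}^{N} \frac{\frac{1}{K}\sum_{k=1}^{K} \lVert \hat{p}_{i,k} - p_{i,k} \rVert_2}{d_i}
```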
247 changes: 247 additions & 0 deletions configs/face_2d_keypoint/rtmpose/lapa/rtmpose-m_8xb64-120e_lapa-256x256.py

@@ -0,0 +1,247 @@
_base_ = ['../../../_base_/default_runtime.py']

# runtime
max_epochs = 120
stage2_num_epochs = 10
base_lr = 4e-3

train_cfg = dict(max_epochs=max_epochs, val_interval=1)
randomness = dict(seed=21)

# optimizer
optim_wrapper = dict(
    type='OptimWrapper',
    optimizer=dict(type='AdamW', lr=base_lr, weight_decay=0.05),
    paramwise_cfg=dict(
        norm_decay_mult=0, bias_decay_mult=0, bypass_duplicate=True))

# learning rate
param_scheduler = [
    dict(
        type='LinearLR',
        start_factor=1.0e-5,
        by_epoch=False,
        begin=0,
        end=1000),
    dict(
        # use cosine lr from epoch 60 to 120 (the second half of training)
        type='CosineAnnealingLR',
        eta_min=base_lr * 0.05,
        begin=max_epochs // 2,
        end=max_epochs,
        T_max=max_epochs // 2,
        by_epoch=True,
        convert_to_iter_based=True),
]

# automatically scaling LR based on the actual training batch size
auto_scale_lr = dict(base_batch_size=512)

# codec settings
codec = dict(
    type='SimCCLabel',
    input_size=(256, 256),
    sigma=(5.66, 5.66),
    simcc_split_ratio=2.0,
    normalize=False,
    use_dark=False)
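# With input_size=(256, 256) and simcc_split_ratio=2.0, each SimCC
# classification axis is discretized into 256 * 2 = 512 bins.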

# model settings
model = dict(
    type='TopdownPoseEstimator',
    data_preprocessor=dict(
        type='PoseDataPreprocessor',
        mean=[123.675, 116.28, 103.53],
        std=[58.395, 57.12, 57.375],
        bgr_to_rgb=True),
    backbone=dict(
        _scope_='mmdet',
        type='CSPNeXt',
        arch='P5',
        expand_ratio=0.5,
        deepen_factor=0.67,
        widen_factor=0.75,
        out_indices=(4, ),
        channel_attention=True,
        norm_cfg=dict(type='SyncBN'),
        act_cfg=dict(type='SiLU'),
        init_cfg=dict(
            type='Pretrained',
            prefix='backbone.',
            checkpoint='https://download.openmmlab.com/mmpose/v1/projects/'
            'rtmposev1/cspnext-m_udp-aic-coco_210e-256x192-f2f7d6f6_20230130.pth'  # noqa
        )),
    head=dict(
        type='RTMCCHead',
        in_channels=768,
        # LaPa annotates 106 facial landmarks, one output channel each
        out_channels=106,
        input_size=codec['input_size'],
        in_featuremap_size=(8, 8),
        simcc_split_ratio=codec['simcc_split_ratio'],
        final_layer_kernel_size=7,
        gau_cfg=dict(
            hidden_dims=256,
            s=128,
            expansion_factor=2,
            dropout_rate=0.,
            drop_path=0.,
            act_fn='SiLU',
            use_rel_bias=False,
            pos_enc=False),
        loss=dict(
            type='KLDiscretLoss',
            use_target_weight=True,
            beta=10.,
            label_softmax=True),
        decoder=codec),
    test_cfg=dict(flip_test=True, ))

# base dataset settings
dataset_type = 'LapaDataset'
data_mode = 'topdown'
data_root = 'data/LaPa/'

backend_args = dict(backend='local')
# backend_args = dict(
#     backend='petrel',
#     path_mapping=dict({
#         f'{data_root}': 's3://openmmlab/datasets/pose/LaPa/',
#         f'{data_root}': 's3://openmmlab/datasets/pose/LaPa/'
#     }))

# pipelines
train_pipeline = [
    dict(type='LoadImage', backend_args=backend_args),
    dict(type='GetBBoxCenterScale'),
    dict(type='RandomFlip', direction='horizontal'),
    dict(type='RandomHalfBody'),
    dict(
        type='RandomBBoxTransform', scale_factor=[0.5, 1.5], rotate_factor=80),
    dict(type='TopdownAffine', input_size=codec['input_size']),
    dict(type='mmdet.YOLOXHSVRandomAug'),
    dict(type='PhotometricDistortion'),
    dict(
        type='Albumentation',
        transforms=[
            dict(type='Blur', p=0.2),
            dict(type='MedianBlur', p=0.2),
            dict(
                type='CoarseDropout',
                max_holes=1,
                max_height=0.4,
                max_width=0.4,
                min_holes=1,
                min_height=0.2,
                min_width=0.2,
                p=1.0),
        ]),
    dict(type='GenerateTarget', encoder=codec),
    dict(type='PackPoseInputs')
]
val_pipeline = [
    dict(type='LoadImage', backend_args=backend_args),
    dict(type='GetBBoxCenterScale'),
    dict(type='TopdownAffine', input_size=codec['input_size']),
    dict(type='PackPoseInputs')
]

train_pipeline_stage2 = [
    dict(type='LoadImage', backend_args=backend_args),
    dict(type='GetBBoxCenterScale'),
    dict(type='RandomFlip', direction='horizontal'),
    # dict(type='RandomHalfBody'),
    dict(
        type='RandomBBoxTransform',
        shift_factor=0.,
        scale_factor=[0.75, 1.25],
        rotate_factor=60),
    dict(type='TopdownAffine', input_size=codec['input_size']),
    dict(type='mmdet.YOLOXHSVRandomAug'),
    dict(
        type='Albumentation',
        transforms=[
            dict(type='Blur', p=0.1),
            dict(type='MedianBlur', p=0.1),
            dict(
                type='CoarseDropout',
                max_holes=1,
                max_height=0.4,
                max_width=0.4,
                min_holes=1,
                min_height=0.2,
                min_width=0.2,
                p=0.5),
        ]),
    dict(type='GenerateTarget', encoder=codec),
    dict(type='PackPoseInputs')
]

# data loaders
train_dataloader = dict(
    batch_size=32,
    num_workers=10,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=True),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        data_mode=data_mode,
        ann_file='annotations/lapa_train.json',
        data_prefix=dict(img='train/images/'),
        pipeline=train_pipeline,
    ))
val_dataloader = dict(
    batch_size=32,
    num_workers=10,
    persistent_workers=True,
    drop_last=False,
    sampler=dict(type='DefaultSampler', shuffle=False, round_up=False),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        data_mode=data_mode,
        ann_file='annotations/lapa_val.json',
        data_prefix=dict(img='val/images/'),
        test_mode=True,
        pipeline=val_pipeline,
    ))
test_dataloader = dict(
    batch_size=32,
    num_workers=10,
    persistent_workers=True,
    drop_last=False,
    sampler=dict(type='DefaultSampler', shuffle=False, round_up=False),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        data_mode=data_mode,
        ann_file='annotations/lapa_test.json',
        data_prefix=dict(img='test/images/'),
        test_mode=True,
        pipeline=val_pipeline,
    ))

# hooks
default_hooks = dict(
    checkpoint=dict(
        save_best='NME', rule='less', max_keep_ckpts=1, interval=1))

custom_hooks = [
    dict(
        type='EMAHook',
        ema_type='ExpMomentumEMA',
        momentum=0.0002,
        update_buffers=True,
        priority=49),
    dict(
        # switch to the milder stage-2 augmentations for the last 10 epochs
        type='mmdet.PipelineSwitchHook',
        switch_epoch=max_epochs - stage2_num_epochs,
        switch_pipeline=train_pipeline_stage2)
]

# evaluators
val_evaluator = dict(
    type='NME',
    norm_mode='keypoint_distance',
)
test_evaluator = val_evaluator
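
As a quick usage sketch, the config above can be loaded and inspected with MMEngine (assuming `mmpose`/`mmengine` are installed and the repository root is the working directory; the expected values follow directly from the config):

```python
# Sketch: load the LaPa RTMPose config and inspect a few key fields.
# Assumes mmpose/mmengine are installed and the cwd is the repo root.
from mmengine.config import Config

cfg = Config.fromfile(
    'configs/face_2d_keypoint/rtmpose/lapa/rtmpose-m_8xb64-120e_lapa-256x256.py')
print(cfg.model.head.out_channels)            # 106 facial landmarks
print(cfg.codec['input_size'])                # (256, 256)
print(cfg.train_dataloader.dataset.ann_file)  # annotations/lapa_train.json
```

Training then goes through the standard entry point, e.g. `python tools/train.py` with this config.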
40 changes: 40 additions & 0 deletions configs/face_2d_keypoint/rtmpose/lapa/rtmpose_lapa.md
@@ -0,0 +1,40 @@
<!-- [ALGORITHM] -->

<details>
<summary align="right"><a href="https://arxiv.org/abs/2212.07784">RTMDet (ArXiv 2022)</a></summary>

```bibtex
@misc{lyu2022rtmdet,
  title={RTMDet: An Empirical Study of Designing Real-Time Object Detectors},
  author={Chengqi Lyu and Wenwei Zhang and Haian Huang and Yue Zhou and Yudong Wang and Yanyi Liu and Shilong Zhang and Kai Chen},
  year={2022},
  eprint={2212.07784},
  archivePrefix={arXiv},
  primaryClass={cs.CV}
}
```

</details>

<!-- [DATASET] -->

<details>
<summary align="right"><a href="https://aaai.org/ojs/index.php/AAAI/article/view/6832/6686">LaPa (AAAI'2020)</a></summary>

```bibtex
@inproceedings{liu2020new,
  title={A New Dataset and Boundary-Attention Semantic Segmentation for Face Parsing},
  author={Liu, Yinglu and Shi, Hailin and Shen, Hao and Si, Yue and Wang, Xiaobo and Mei, Tao},
  booktitle={AAAI},
  pages={11637--11644},
  year={2020}
}
```

</details>

Results on LaPa val set

| Arch | Input Size | NME | ckpt | log |
| :------------------------------------------------------------- | :--------: | :--: | :------------------------------------------------------------: | :------------------------------------------------------------: |
| [pose_rtmpose_m](/configs/face_2d_keypoint/rtmpose/lapa/rtmpose-m_8xb64-120e_lapa-256x256.py) | 256x256 | 1.29 | [ckpt](https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/rtmpose-m_simcc-lapa_pt-aic-coco_120e-256x256-762b1ae2_20230422.pth) | [log](https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/rtmpose-m_simcc-lapa_pt-aic-coco_120e-256x256-762b1ae2_20230422.json) |
15 changes: 15 additions & 0 deletions configs/face_2d_keypoint/rtmpose/lapa/rtmpose_lapa.yml
@@ -0,0 +1,15 @@
Models:
- Config: configs/face_2d_keypoint/rtmpose/lapa/rtmpose-m_8xb64-120e_lapa-256x256.py
  In Collection: RTMPose
  Alias: face
  Metadata:
    Architecture:
    - RTMPose
    Training Data: LaPa
  Name: rtmpose-m_8xb64-120e_lapa-256x256
  Results:
  - Dataset: LaPa
    Metrics:
      NME: 1.29
    Task: Face 2D Keypoint
  Weights: https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/rtmpose-m_simcc-lapa_pt-aic-coco_120e-256x256-762b1ae2_20230422.pth
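
Note the `Alias: face` entry: with mmpose 1.x this model should be reachable through the high-level inferencer. A minimal sketch (assuming the `MMPoseInferencer` API is available and that the alias resolves to this checkpoint; `path/to/face.jpg` is a placeholder):

```python
# Sketch: face keypoint inference via the 'face' alias registered above.
from mmpose.apis import MMPoseInferencer

inferencer = MMPoseInferencer(pose2d='face')  # resolves the alias
result = next(inferencer('path/to/face.jpg'))  # yields one result per input
print(result['predictions'])  # predicted 106-point landmarks per instance
```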
3 changes: 2 additions & 1 deletion demo/topdown_demo_with_mmdet.py
@@ -214,7 +214,8 @@ def main():

        if output_file:
            img_vis = visualizer.get_image()
-           mmcv.imwrite(mmcv.rgb2bgr(img_vis), output_file)
+           if args.show:
+               mmcv.imwrite(mmcv.rgb2bgr(img_vis), output_file)

    elif input_type in ['webcam', 'video']:
        from mmpose.visualization import FastVisualizer
56 changes: 56 additions & 0 deletions docs/en/dataset_zoo/2d_face_keypoint.md
@@ -10,6 +10,7 @@ MMPose supported datasets:
- [AFLW](#aflw-dataset) \[ [Homepage](https://www.tugraz.at/institute/icg/research/team-bischof/lrs/downloads/aflw/) \]
- [COFW](#cofw-dataset) \[ [Homepage](http://www.vision.caltech.edu/xpburgos/ICCV13/) \]
- [COCO-WholeBody-Face](#coco-wholebody-face) \[ [Homepage](https://github.com/jin-s13/COCO-WholeBody/) \]
- [LaPa](#lapa-dataset) \[ [Homepage](https://github.com/JDAI-CV/lapa-dataset) \]

## 300W Dataset

@@ -325,3 +326,58 @@ mmpose
Please also install the latest version of [Extended COCO API](https://github.com/jin-s13/xtcocoapi) to support COCO-WholeBody evaluation:

`pip install xtcocotools`

## LaPa

<!-- [DATASET] -->

<details>
<summary align="right"><a href="https://aaai.org/ojs/index.php/AAAI/article/view/6832/6686">LaPa (AAAI'2020)</a></summary>

```bibtex
@inproceedings{liu2020new,
  title={A New Dataset and Boundary-Attention Semantic Segmentation for Face Parsing},
  author={Liu, Yinglu and Shi, Hailin and Shen, Hao and Si, Yue and Wang, Xiaobo and Mei, Tao},
  booktitle={AAAI},
  pages={11637--11644},
  year={2020}
}
```

</details>

<div align="center">
<img src="https://github.com/lucia123/lapa-dataset/raw/master/sample.png" height="200px">
</div>

For the [LaPa](https://github.com/JDAI-CV/lapa-dataset) dataset, images can be downloaded from its [GitHub page](https://github.com/JDAI-CV/lapa-dataset).

Download and extract the files under `$MMPOSE/data`, then use our `tools/dataset_converters/lapa2coco.py` to convert the annotations so that the directory looks like this (a hypothetical sketch of the conversion follows the tree):

```text
mmpose
├── mmpose
├── docs
├── tests
├── tools
├── configs
└── data
    └── LaPa
        ├── annotations
        │   ├── lapa_train.json
        │   ├── lapa_val.json
        │   └── lapa_test.json
        ├── train
        │   ├── images
        │   ├── labels
        │   └── landmarks
        ├── val
        │   ├── images
        │   ├── labels
        │   └── landmarks
        └── test
            ├── images
            ├── labels
            └── landmarks
```
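
The converter's internals are not shown in this diff; below is a rough, hypothetical sketch of the kind of work it does. The landmark-file layout (a point count on the first line followed by one `x y` pair per line) and all helper names here are assumptions; consult the real `tools/dataset_converters/lapa2coco.py` before relying on any of this.

```python
# Hypothetical sketch of a LaPa -> COCO-style conversion for one split.
# Assumed layout: <split>/images/*.jpg and <split>/landmarks/*.txt, where
# each landmark file starts with the point count (106 for LaPa) followed
# by one "x y" pair per line. The real lapa2coco.py may differ (it also
# needs image sizes and bounding boxes, omitted here for brevity).
import json
import os


def lapa_split_to_coco(lapa_root, split, out_json):
    images, annotations = [], []
    lm_dir = os.path.join(lapa_root, split, 'landmarks')
    for idx, fname in enumerate(sorted(os.listdir(lm_dir))):
        stem = os.path.splitext(fname)[0]
        with open(os.path.join(lm_dir, fname)) as f:
            tokens = f.read().split()
        num_pts = int(tokens[0])
        coords = list(map(float, tokens[1:1 + 2 * num_pts]))
        # COCO keypoints are a flat [x1, y1, v1, x2, y2, v2, ...] list;
        # visibility flag 2 means "labeled and visible".
        keypoints = []
        for x, y in zip(coords[0::2], coords[1::2]):
            keypoints += [x, y, 2]
        images.append(dict(id=idx, file_name=stem + '.jpg'))
        annotations.append(dict(
            id=idx, image_id=idx, category_id=1,
            keypoints=keypoints, num_keypoints=num_pts))
    with open(out_json, 'w') as f:
        json.dump(dict(images=images, annotations=annotations,
                       categories=[dict(id=1, name='face')]), f)


for split in ('train', 'val', 'test'):
    lapa_split_to_coco('data/LaPa', split,
                       'data/LaPa/annotations/lapa_%s.json' % split)
```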