
[Feature] Support LaPa Dataset #2281

Merged: 15 commits, Apr 23, 2023
688 changes: 688 additions & 0 deletions configs/_base_/datasets/lapa.py

Large diffs are not rendered by default.

8 changes: 8 additions & 0 deletions configs/face_2d_keypoint/rtmpose/README.md
@@ -22,3 +22,11 @@ Results on WFLW dataset
| Model | Input Size | NME | Details and Download |
| :-------: | :--------: | :--: | :---------------------------------------: |
| RTMPose-m | 256x256 | 4.01 | [rtmpose_wflw.md](./wflw/rtmpose_wflw.md) |

### LaPa Dataset

Results on LaPa dataset

| Model | Input Size | NME | Details and Download |
| :-------: | :--------: | :--: | :---------------------------------------: |
| RTMPose-m | 256x256 | 1.29 | [rtmpose_lapa.md](./lapa/rtmpose_lapa.md) |
247 changes: 247 additions & 0 deletions configs/face_2d_keypoint/rtmpose/lapa/rtmpose-m_8xb64-120e_lapa-256x256.py
@@ -0,0 +1,247 @@
_base_ = ['../../../_base_/default_runtime.py']

# runtime
max_epochs = 120
stage2_num_epochs = 10
base_lr = 4e-3

train_cfg = dict(max_epochs=max_epochs, val_interval=1)
randomness = dict(seed=21)

# optimizer
optim_wrapper = dict(
type='OptimWrapper',
optimizer=dict(type='AdamW', lr=base_lr, weight_decay=0.05),
paramwise_cfg=dict(
norm_decay_mult=0, bias_decay_mult=0, bypass_duplicate=True))

# learning rate
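# linear warm-up over the first 1000 iterations, then cosine annealing
# across the second half of training (epochs 60-120)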
param_scheduler = [
dict(
type='LinearLR',
start_factor=1.0e-5,
by_epoch=False,
begin=0,
end=1000),
dict(
# use cosine lr from epoch 60 to 120
type='CosineAnnealingLR',
eta_min=base_lr * 0.05,
begin=max_epochs // 2,
end=max_epochs,
T_max=max_epochs // 2,
by_epoch=True,
convert_to_iter_based=True),
]

# automatically scaling LR based on the actual training batch size
auto_scale_lr = dict(base_batch_size=512)

# codec settings
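# SimCC encodes each keypoint coordinate as a 1-D classification over
# input_size * simcc_split_ratio bins (256 * 2.0 = 512 bins per axis)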
codec = dict(
type='SimCCLabel',
input_size=(256, 256),
sigma=(5.66, 5.66),
simcc_split_ratio=2.0,
normalize=False,
use_dark=False)

# model settings
model = dict(
type='TopdownPoseEstimator',
data_preprocessor=dict(
type='PoseDataPreprocessor',
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
bgr_to_rgb=True),
backbone=dict(
_scope_='mmdet',
type='CSPNeXt',
arch='P5',
expand_ratio=0.5,
deepen_factor=0.67,
widen_factor=0.75,
out_indices=(4, ),
channel_attention=True,
norm_cfg=dict(type='SyncBN'),
act_cfg=dict(type='SiLU'),
init_cfg=dict(
type='Pretrained',
prefix='backbone.',
checkpoint='https://download.openmmlab.com/mmpose/v1/projects/'
'rtmposev1/cspnext-m_udp-aic-coco_210e-256x192-f2f7d6f6_20230130.pth' # noqa
)),
head=dict(
type='RTMCCHead',
in_channels=768,
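# LaPa annotates 106 facial landmarks, hence 106 output channels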
out_channels=106,
input_size=codec['input_size'],
in_featuremap_size=(8, 8),
simcc_split_ratio=codec['simcc_split_ratio'],
final_layer_kernel_size=7,
gau_cfg=dict(
hidden_dims=256,
s=128,
expansion_factor=2,
dropout_rate=0.,
drop_path=0.,
act_fn='SiLU',
use_rel_bias=False,
pos_enc=False),
loss=dict(
type='KLDiscretLoss',
use_target_weight=True,
beta=10.,
label_softmax=True),
decoder=codec),
test_cfg=dict(flip_test=True, ))

# base dataset settings
dataset_type = 'LapaDataset'
data_mode = 'topdown'
data_root = 'data/LaPa/'

backend_args = dict(backend='local')
# backend_args = dict(
# backend='petrel',
# path_mapping=dict({
# f'{data_root}': 's3://openmmlab/datasets/pose/LaPa/',
# f'{data_root}': 's3://openmmlab/datasets/pose/LaPa/'
# }))
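# the commented block above reads data from a Ceph bucket through the
# petrel backend instead of the local filesystem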

# pipelines
train_pipeline = [
dict(type='LoadImage', backend_args=backend_args),
dict(type='GetBBoxCenterScale'),
dict(type='RandomFlip', direction='horizontal'),
dict(type='RandomHalfBody'),
dict(
type='RandomBBoxTransform', scale_factor=[0.5, 1.5], rotate_factor=80),
dict(type='TopdownAffine', input_size=codec['input_size']),
dict(type='mmdet.YOLOXHSVRandomAug'),
dict(type='PhotometricDistortion'),
dict(
type='Albumentation',
transforms=[
dict(type='Blur', p=0.2),
dict(type='MedianBlur', p=0.2),
dict(
type='CoarseDropout',
max_holes=1,
max_height=0.4,
max_width=0.4,
min_holes=1,
min_height=0.2,
min_width=0.2,
p=1.0),
]),
dict(type='GenerateTarget', encoder=codec),
dict(type='PackPoseInputs')
]
val_pipeline = [
dict(type='LoadImage', backend_args=backend_args),
dict(type='GetBBoxCenterScale'),
dict(type='TopdownAffine', input_size=codec['input_size']),
dict(type='PackPoseInputs')
]

train_pipeline_stage2 = [
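# milder augmentation for the final stage2_num_epochs epochs, enabled by
# the PipelineSwitchHook below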
dict(type='LoadImage', backend_args=backend_args),
dict(type='GetBBoxCenterScale'),
dict(type='RandomFlip', direction='horizontal'),
# dict(type='RandomHalfBody'),
dict(
type='RandomBBoxTransform',
shift_factor=0.,
scale_factor=[0.75, 1.25],
rotate_factor=60),
dict(type='TopdownAffine', input_size=codec['input_size']),
dict(type='mmdet.YOLOXHSVRandomAug'),
dict(
type='Albumentation',
transforms=[
dict(type='Blur', p=0.1),
dict(type='MedianBlur', p=0.1),
dict(
type='CoarseDropout',
max_holes=1,
max_height=0.4,
max_width=0.4,
min_holes=1,
min_height=0.2,
min_width=0.2,
p=0.5),
]),
dict(type='GenerateTarget', encoder=codec),
dict(type='PackPoseInputs')
]

# data loaders
train_dataloader = dict(
batch_size=32,
num_workers=10,
persistent_workers=True,
sampler=dict(type='DefaultSampler', shuffle=True),
dataset=dict(
type=dataset_type,
data_root=data_root,
data_mode=data_mode,
ann_file='annotations/lapa_train.json',
data_prefix=dict(img='train/images/'),
pipeline=train_pipeline,
))
val_dataloader = dict(
batch_size=32,
num_workers=10,
persistent_workers=True,
drop_last=False,
sampler=dict(type='DefaultSampler', shuffle=False, round_up=False),
dataset=dict(
type=dataset_type,
data_root=data_root,
data_mode=data_mode,
ann_file='annotations/lapa_val.json',
data_prefix=dict(img='val/images/'),
test_mode=True,
pipeline=val_pipeline,
))
test_dataloader = dict(
batch_size=32,
num_workers=10,
persistent_workers=True,
drop_last=False,
sampler=dict(type='DefaultSampler', shuffle=False, round_up=False),
dataset=dict(
type=dataset_type,
data_root=data_root,
data_mode=data_mode,
ann_file='annotations/lapa_test.json',
data_prefix=dict(img='test/images/'),
test_mode=True,
pipeline=val_pipeline,
))

# hooks
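# keep only the single checkpoint with the best (lowest) NME on the val set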
default_hooks = dict(
checkpoint=dict(
save_best='NME', rule='less', max_keep_ckpts=1, interval=1))

custom_hooks = [
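# EMAHook keeps an exponential moving average of the model weights;
# PipelineSwitchHook swaps in the stage-2 pipeline for the last 10 epochs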
dict(
type='EMAHook',
ema_type='ExpMomentumEMA',
momentum=0.0002,
update_buffers=True,
priority=49),
dict(
type='mmdet.PipelineSwitchHook',
switch_epoch=max_epochs - stage2_num_epochs,
switch_pipeline=train_pipeline_stage2)
]

# evaluators
val_evaluator = dict(
type='NME',
norm_mode='keypoint_distance',
)
test_evaluator = val_evaluator
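
As a quick sanity check, the config can be loaded and inspected with MMEngine's `Config` API. A minimal sketch, assuming it is run from the repository root with `mmengine` installed:

```python
from mmengine.config import Config

# Load the LaPa RTMPose config added in this PR and inspect a few fields.
cfg = Config.fromfile(
    'configs/face_2d_keypoint/rtmpose/lapa/rtmpose-m_8xb64-120e_lapa-256x256.py')
print(cfg.model.head.out_channels)            # 106, one per LaPa landmark
print(cfg.codec.input_size)                   # (256, 256)
print(cfg.train_dataloader.dataset.ann_file)  # annotations/lapa_train.json
```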
40 changes: 40 additions & 0 deletions configs/face_2d_keypoint/rtmpose/lapa/rtmpose_lapa.md
@@ -0,0 +1,40 @@
<!-- [ALGORITHM] -->

<details>
<summary align="right"><a href="https://arxiv.org/abs/2212.07784">RTMDet (ArXiv 2022)</a></summary>

```bibtex
@misc{lyu2022rtmdet,
title={RTMDet: An Empirical Study of Designing Real-Time Object Detectors},
author={Chengqi Lyu and Wenwei Zhang and Haian Huang and Yue Zhou and Yudong Wang and Yanyi Liu and Shilong Zhang and Kai Chen},
year={2022},
eprint={2212.07784},
archivePrefix={arXiv},
primaryClass={cs.CV}
}
```

</details>

<!-- [DATASET] -->

<details>
<summary align="right"><a href="https://aaai.org/ojs/index.php/AAAI/article/view/6832/6686">LaPa (AAAI'2020)</a></summary>

```bibtex
@inproceedings{liu2020new,
title={A New Dataset and Boundary-Attention Semantic Segmentation for Face Parsing.},
author={Liu, Yinglu and Shi, Hailin and Shen, Hao and Si, Yue and Wang, Xiaobo and Mei, Tao},
booktitle={AAAI},
pages={11637--11644},
year={2020}
}
```

</details>

Results on LaPa val set

| Arch | Input Size | NME | ckpt | log |
| :------------------------------------------------------------- | :--------: | :--: | :------------------------------------------------------------: | :------------------------------------------------------------: |
| [pose_rtmpose_m](/configs/face_2d_keypoint/rtmpose/lapa/rtmpose-m_8xb64-120e_lapa-256x256.py) | 256x256 | 1.29 | [ckpt](https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/rtmpose-m_simcc-lapa_pt-aic-coco_120e-256x256-762b1ae2_20230422.pth) | [log](https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/rtmpose-m_simcc-lapa_pt-aic-coco_120e-256x256-762b1ae2_20230422.json) |
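
For reference, a minimal inference sketch using the released checkpoint. It assumes the MMPose 1.x high-level APIs (`init_model`, `inference_topdown`); `face.jpg` is a hypothetical pre-cropped face image:

```python
from mmpose.apis import inference_topdown, init_model

config = ('configs/face_2d_keypoint/rtmpose/lapa/'
          'rtmpose-m_8xb64-120e_lapa-256x256.py')
checkpoint = ('https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/'
              'rtmpose-m_simcc-lapa_pt-aic-coco_120e-256x256-762b1ae2_20230422.pth')

# Build the model; with no explicit bboxes, inference_topdown treats the
# whole image as a single face instance.
model = init_model(config, checkpoint, device='cpu')
results = inference_topdown(model, 'face.jpg')
print(results[0].pred_instances.keypoints.shape)  # (1, 106, 2)
```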
15 changes: 15 additions & 0 deletions configs/face_2d_keypoint/rtmpose/lapa/rtmpose_lapa.yml
@@ -0,0 +1,15 @@
Models:
- Config: configs/face_2d_keypoint/rtmpose/lapa/rtmpose-m_8xb64-120e_lapa-256x256.py
In Collection: RTMPose
Alias: face
Metadata:
Architecture:
- RTMPose
Training Data: LaPa
Name: rtmpose-m_8xb64-120e_lapa-256x256
Results:
- Dataset: LaPa
Metrics:
NME: 1.29
Task: Face 2D Keypoint
Weights: https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/rtmpose-m_simcc-lapa_pt-aic-coco_120e-256x256-762b1ae2_20230422.pth
3 changes: 2 additions & 1 deletion demo/topdown_demo_with_mmdet.py
@@ -214,7 +214,8 @@ def main():

if output_file:
img_vis = visualizer.get_image()
-mmcv.imwrite(mmcv.rgb2bgr(img_vis), output_file)
+if args.show:
+    mmcv.imwrite(mmcv.rgb2bgr(img_vis), output_file)

elif input_type in ['webcam', 'video']:
from mmpose.visualization import FastVisualizer
56 changes: 56 additions & 0 deletions docs/en/dataset_zoo/2d_face_keypoint.md
@@ -10,6 +10,7 @@ MMPose supported datasets:
- [AFLW](#aflw-dataset) \[ [Homepage](https://www.tugraz.at/institute/icg/research/team-bischof/lrs/downloads/aflw/) \]
- [COFW](#cofw-dataset) \[ [Homepage](http://www.vision.caltech.edu/xpburgos/ICCV13/) \]
- [COCO-WholeBody-Face](#coco-wholebody-face) \[ [Homepage](https://github.com/jin-s13/COCO-WholeBody/) \]
- [LaPa](#lapa-dataset) \[ [Homepage](https://github.com/JDAI-CV/lapa-dataset) \]

## 300W Dataset

@@ -325,3 +326,58 @@ mmpose
Please also install the latest version of [Extended COCO API](https://github.com/jin-s13/xtcocoapi) to support COCO-WholeBody evaluation:

`pip install xtcocotools`

## LaPa

<!-- [DATASET] -->

<details>
<summary align="right"><a href="https://aaai.org/ojs/index.php/AAAI/article/view/6832/6686">LaPa (AAAI'2020)</a></summary>

```bibtex
@inproceedings{liu2020new,
title={A New Dataset and Boundary-Attention Semantic Segmentation for Face Parsing.},
author={Liu, Yinglu and Shi, Hailin and Shen, Hao and Si, Yue and Wang, Xiaobo and Mei, Tao},
booktitle={AAAI},
pages={11637--11644},
year={2020}
}
```

</details>

<div align="center">
<img src="https://github.com/lucia123/lapa-dataset/raw/master/sample.png" height="200px">
</div>

For the [LaPa](https://github.com/JDAI-CV/lapa-dataset) dataset, images can be downloaded from the [official GitHub repository](https://github.com/JDAI-CV/lapa-dataset).

Download and extract them under `$MMPOSE/data`, then run our `tools/dataset_converters/lapa2coco.py` script to convert the annotations into COCO format, so that the directory tree looks like this:

```text
mmpose
├── mmpose
├── docs
├── tests
├── tools
├── configs
`── data
    │── LaPa
        │-- annotations
        │   │-- lapa_train.json
        │   │-- lapa_val.json
        │   │-- lapa_test.json
        │-- train
        │   │-- images
        │   │-- labels
        │   │-- landmarks
        │-- val
        │   │-- images
        │   │-- labels
        │   │-- landmarks
        `-- test
            │-- images
            │-- labels
            │-- landmarks
```
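
After running the converter, a quick spot-check that the generated annotation files have the COCO-style structure `LapaDataset` expects. A sketch; the key layout assumes standard COCO keypoint annotations:

```python
import json

# Spot-check the converted training annotations.
with open('data/LaPa/annotations/lapa_train.json') as f:
    ann = json.load(f)

print(sorted(ann))            # ['annotations', 'categories', 'images']
print(len(ann['images']), len(ann['annotations']))
# Each annotation stores a flat [x, y, visibility] triple per landmark:
print(len(ann['annotations'][0]['keypoints']))  # 106 * 3 = 318
```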