[Feature] Support ubody dataset (2d keypoints) (#2588)
Showing 14 changed files with 1,797 additions and 9 deletions.
configs/wholebody_2d_keypoint/topdown_heatmap/ubody2d/hrnet_coco-wholebody.yml (23 additions, 0 deletions)

```yaml
Models:
- Config: configs/wholebody_2d_keypoint/topdown_heatmap/ubody2d/td-hm_hrnet-w32_8xb64-210e_ubody-256x192.py
  In Collection: HRNet
  Metadata:
    Architecture: &id001
    - HRNet
    Training Data: UBody-COCO-WholeBody
  Name: td-hm_hrnet-w32_8xb64-210e_ubody-256x192
  Results:
  - Dataset: COCO-WholeBody
    Metrics:
      Body AP: 0.678
      Body AR: 0.755
      Face AP: 0.630
      Face AR: 0.708
      Foot AP: 0.543
      Foot AR: 0.661
      Hand AP: 0.467
      Hand AR: 0.566
      Whole AP: 0.536
      Whole AR: 0.636
    Task: Wholebody 2D Keypoint
  Weights: https://download.openmmlab.com/mmpose/v1/wholebody_2d_keypoint/ubody/td-hm_hrnet-w32_8xb64-210e_ubody-coco-256x192-7c227391_20230807.pth
```
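For a quick sanity check, the released checkpoint listed in the model index above can be loaded with mmpose's high-level inference API. A minimal sketch, assuming an mmpose 1.x installation; `demo.jpg` is a placeholder image path, and the config path is the one added in this commit:

```python
# Minimal sketch: top-down whole-body inference with the released checkpoint.
# Assumes mmpose 1.x is installed; 'demo.jpg' is a placeholder test image.
from mmpose.apis import inference_topdown, init_model

config = ('configs/wholebody_2d_keypoint/topdown_heatmap/ubody2d/'
          'td-hm_hrnet-w32_8xb64-210e_ubody-256x192.py')
checkpoint = ('https://download.openmmlab.com/mmpose/v1/wholebody_2d_keypoint/'
              'ubody/td-hm_hrnet-w32_8xb64-210e_ubody-coco-256x192-'
              '7c227391_20230807.pth')

model = init_model(config, checkpoint, device='cpu')

# Without explicit bounding boxes, the whole image is treated as one instance.
results = inference_topdown(model, 'demo.jpg')
print(results[0].pred_instances.keypoints.shape)  # expected: (1, 133, 2)
```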
...igs/wholebody_2d_keypoint/topdown_heatmap/ubody2d/hrnet_ubody-coco-wholebody.md (38 additions, 0 deletions)

````markdown
<!-- [ALGORITHM] -->

<details>
<summary align="right"><a href="http://openaccess.thecvf.com/content_CVPR_2019/html/Sun_Deep_High-Resolution_Representation_Learning_for_Human_Pose_Estimation_CVPR_2019_paper.html">HRNet (CVPR'2019)</a></summary>

```bibtex
@inproceedings{sun2019deep,
  title={Deep high-resolution representation learning for human pose estimation},
  author={Sun, Ke and Xiao, Bin and Liu, Dong and Wang, Jingdong},
  booktitle={Proceedings of the IEEE conference on computer vision and pattern recognition},
  pages={5693--5703},
  year={2019}
}
```

</details>

<!-- [DATASET] -->

<details>
<summary align="right"><a href="https://arxiv.org/abs/2303.16160">UBody (CVPR'2023)</a></summary>

```bibtex
@inproceedings{lin2023one,
  title={One-Stage 3D Whole-Body Mesh Recovery with Component Aware Transformer},
  author={Lin, Jing and Zeng, Ailing and Wang, Haoqian and Zhang, Lei and Li, Yu},
  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
  year={2023}
}
```

</details>

Results on the COCO-WholeBody v1.0 val set, using a person detector with 56.4 human AP on the COCO val2017 dataset:

| Arch | Input Size | Body AP | Body AR | Foot AP | Foot AR | Face AP | Face AR | Hand AP | Hand AR | Whole AP | Whole AR | ckpt | log |
| :--- | :--------: | :-----: | :-----: | :-----: | :-----: | :-----: | :-----: | :-----: | :-----: | :------: | :------: | :--: | :-: |
| [pose_hrnet_w32](/configs/wholebody_2d_keypoint/topdown_heatmap/ubody2d/td-hm_hrnet-w32_8xb64-210e_ubody-256x192.py) | 256x192 | 0.685 | 0.759 | 0.564 | 0.675 | 0.625 | 0.705 | 0.516 | 0.609 | 0.549 | 0.646 | [ckpt](https://download.openmmlab.com/mmpose/v1/wholebody_2d_keypoint/ubody/td-hm_hrnet-w32_8xb64-210e_ubody-coco-256x192-7c227391_20230807.pth) | [log](https://download.openmmlab.com/mmpose/v1/wholebody_2d_keypoint/ubody/td-hm_hrnet-w32_8xb64-210e_ubody-coco-256x192-7c227391_20230807.json) |
````
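The numbers in this table come from COCO-WholeBody evaluation with `CocoWholeBodyMetric`, as configured in the training config below. A rough sketch of reproducing them (not part of this commit), assuming mmengine/mmpose 1.x and the COCO-WholeBody val data prepared under `data/coco/`; `work_dir` is a placeholder:

```python
# Rough sketch: run this config's test loop against the released checkpoint.
from mmengine.config import Config
from mmengine.runner import Runner

cfg = Config.fromfile(
    'configs/wholebody_2d_keypoint/topdown_heatmap/ubody2d/'
    'td-hm_hrnet-w32_8xb64-210e_ubody-256x192.py')
cfg.work_dir = 'work_dirs/ubody2d_eval'  # placeholder output directory
cfg.load_from = (
    'https://download.openmmlab.com/mmpose/v1/wholebody_2d_keypoint/ubody/'
    'td-hm_hrnet-w32_8xb64-210e_ubody-coco-256x192-7c227391_20230807.pth')

runner = Runner.from_cfg(cfg)
metrics = runner.test()  # reports body/foot/face/hand/whole AP and AR
print(metrics)
```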
...wholebody_2d_keypoint/topdown_heatmap/ubody2d/td-hm_hrnet-w32_8xb64-210e_ubody-256x192.py (173 additions, 0 deletions)

```python
_base_ = ['../../../_base_/default_runtime.py']

# runtime
train_cfg = dict(max_epochs=210, val_interval=10)

# optimizer
optim_wrapper = dict(optimizer=dict(
    type='Adam',
    lr=5e-4,
))

# learning policy
param_scheduler = [
    dict(
        type='LinearLR', begin=0, end=500, start_factor=0.001,
        by_epoch=False),  # warm-up
    dict(
        type='MultiStepLR',
        begin=0,
        end=210,
        milestones=[170, 200],
        gamma=0.1,
        by_epoch=True)
]

# automatically scaling LR based on the actual training batch size
auto_scale_lr = dict(base_batch_size=512)

# hooks
default_hooks = dict(
    checkpoint=dict(save_best='coco-wholebody/AP', rule='greater'))

# codec settings
codec = dict(
    type='MSRAHeatmap', input_size=(192, 256), heatmap_size=(48, 64), sigma=2)

# model settings
model = dict(
    type='TopdownPoseEstimator',
    data_preprocessor=dict(
        type='PoseDataPreprocessor',
        mean=[123.675, 116.28, 103.53],
        std=[58.395, 57.12, 57.375],
        bgr_to_rgb=True),
    backbone=dict(
        type='HRNet',
        in_channels=3,
        extra=dict(
            stage1=dict(
                num_modules=1,
                num_branches=1,
                block='BOTTLENECK',
                num_blocks=(4, ),
                num_channels=(64, )),
            stage2=dict(
                num_modules=1,
                num_branches=2,
                block='BASIC',
                num_blocks=(4, 4),
                num_channels=(32, 64)),
            stage3=dict(
                num_modules=4,
                num_branches=3,
                block='BASIC',
                num_blocks=(4, 4, 4),
                num_channels=(32, 64, 128)),
            stage4=dict(
                num_modules=3,
                num_branches=4,
                block='BASIC',
                num_blocks=(4, 4, 4, 4),
                num_channels=(32, 64, 128, 256))),
        init_cfg=dict(
            type='Pretrained',
            checkpoint='https://download.openmmlab.com/mmpose/'
            'pretrain_models/hrnet_w32-36af842e.pth'),
    ),
    head=dict(
        type='HeatmapHead',
        in_channels=32,
        out_channels=133,
        deconv_out_channels=None,
        loss=dict(type='KeypointMSELoss', use_target_weight=True),
        decoder=codec),
    test_cfg=dict(
        flip_test=True,
        flip_mode='heatmap',
        shift_heatmap=True,
    ))

# base dataset settings
dataset_type = 'UBody2dDataset'
data_mode = 'topdown'
data_root = 'data/UBody/'

scenes = [
    'Magic_show', 'Entertainment', 'ConductMusic', 'Online_class', 'TalkShow',
    'Speech', 'Fitness', 'Interview', 'Olympic', 'TVShow', 'Singing',
    'SignLanguage', 'Movie', 'LiveVlog', 'VideoConference'
]

train_datasets = [
    dict(
        type='CocoWholeBodyDataset',
        data_root='data/coco/',
        data_mode=data_mode,
        ann_file='annotations/coco_wholebody_train_v1.0.json',
        data_prefix=dict(img='train2017/'),
        pipeline=[])
]

for scene in scenes:
    train_dataset = dict(
        type=dataset_type,
        data_root=data_root,
        data_mode=data_mode,
        ann_file=f'annotations/{scene}/train_annotations.json',
        data_prefix=dict(img='images/'),
        pipeline=[],
        sample_interval=10)
    train_datasets.append(train_dataset)

# pipelines
train_pipeline = [
    dict(type='LoadImage'),
    dict(type='GetBBoxCenterScale'),
    dict(type='RandomFlip', direction='horizontal'),
    dict(type='RandomHalfBody'),
    dict(type='RandomBBoxTransform'),
    dict(type='TopdownAffine', input_size=codec['input_size']),
    dict(type='GenerateTarget', encoder=codec),
    dict(type='PackPoseInputs')
]
val_pipeline = [
    dict(type='LoadImage'),
    dict(type='GetBBoxCenterScale'),
    dict(type='TopdownAffine', input_size=codec['input_size']),
    dict(type='PackPoseInputs')
]

# data loaders
train_dataloader = dict(
    batch_size=64,
    num_workers=2,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=True),
    dataset=dict(
        type='CombinedDataset',
        metainfo=dict(from_file='configs/_base_/datasets/coco_wholebody.py'),
        datasets=train_datasets,
        pipeline=train_pipeline,
        test_mode=False,
    ))
val_dataloader = dict(
    batch_size=32,
    num_workers=2,
    persistent_workers=True,
    drop_last=False,
    sampler=dict(type='DefaultSampler', shuffle=False, round_up=False),
    dataset=dict(
        type='CocoWholeBodyDataset',
        ann_file='data/coco/annotations/coco_wholebody_val_v1.0.json',
        data_prefix=dict(img='data/coco/val2017/'),
        pipeline=val_pipeline,
        bbox_file='data/coco/person_detection_results/'
        'COCO_val2017_detections_AP_H_56_person.json',
        test_mode=True))
test_dataloader = val_dataloader

val_evaluator = dict(
    type='CocoWholeBodyMetric',
    ann_file='data/coco/annotations/coco_wholebody_val_v1.0.json')
test_evaluator = val_evaluator
```
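Because the config merges COCO-WholeBody and the 15 UBody scenes inside `CombinedDataset`, nothing beyond the config path is needed to train. A minimal single-GPU sketch, assuming mmengine/mmpose 1.x and the data layout the config expects under `data/coco/` and `data/UBody/`; in practice the equivalent is usually launched through `tools/train.py`:

```python
# Minimal single-GPU training sketch (work_dir is a placeholder).
from mmengine.config import Config
from mmengine.runner import Runner

cfg = Config.fromfile(
    'configs/wholebody_2d_keypoint/topdown_heatmap/ubody2d/'
    'td-hm_hrnet-w32_8xb64-210e_ubody-256x192.py')
cfg.work_dir = 'work_dirs/td-hm_hrnet-w32_8xb64-210e_ubody-256x192'

# The LR in the config assumes a global batch size of 512 (8 GPUs x 64);
# enabling auto_scale_lr rescales it for the batch size actually used here.
cfg.auto_scale_lr = dict(enable=True, base_batch_size=512)

runner = Runner.from_cfg(cfg)
runner.train()
```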