Skip to content

Commit 5c52f99

Browse files
committed
add single depth to 3d hand keypoints, add nyu dataset, awr network
1 parent afb37d4 commit 5c52f99

File tree

19 files changed

+2060
-15
lines changed

19 files changed

+2060
-15
lines changed

Diff for: configs/_base_/datasets/nyu.py

+92
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
dataset_info = dict(
2+
dataset_name='nyu',
3+
paper_info=dict(
4+
author='Jonathan Tompson and Murphy Stein and Yann Lecun and '
5+
'Ken Perlin',
6+
title='Real-Time Continuous Pose Recovery of Human Hands '
7+
'Using Convolutional Networks',
8+
container='ACM Transactions on Graphics',
9+
year='2014',
10+
homepage='https://jonathantompson.github.io/NYU_Hand_Pose_Dataset.htm',
11+
),
12+
keypoint_info={
13+
0: dict(name='F1_KNU3_A', id=0, color=[255, 128, 0], type='', swap=''),
14+
1: dict(name='F1_KNU3_B', id=1, color=[255, 128, 0], type='', swap=''),
15+
2: dict(name='F1_KNU2_A', id=2, color=[255, 128, 0], type='', swap=''),
16+
3: dict(name='F1_KNU2_B', id=3, color=[255, 128, 0], type='', swap=''),
17+
4:
18+
dict(name='F1_KNU1_A', id=4, color=[255, 153, 255], type='', swap=''),
19+
5:
20+
dict(name='F1_KNU1_B', id=5, color=[255, 153, 255], type='', swap=''),
21+
6:
22+
dict(name='F2_KNU3_A', id=6, color=[255, 153, 255], type='', swap=''),
23+
7:
24+
dict(name='F2_KNU3_B', id=7, color=[255, 153, 255], type='', swap=''),
25+
8:
26+
dict(name='F2_KNU2_A', id=8, color=[102, 178, 255], type='', swap=''),
27+
9:
28+
dict(name='F2_KNU2_B', id=9, color=[102, 178, 255], type='', swap=''),
29+
10:
30+
dict(name='F2_KNU1_A', id=10, color=[102, 178, 255], type='', swap=''),
31+
11:
32+
dict(name='F2_KNU1_B', id=11, color=[102, 178, 255], type='', swap=''),
33+
12:
34+
dict(name='F3_KNU3_A', id=12, color=[255, 51, 51], type='', swap=''),
35+
13:
36+
dict(name='F3_KNU3_B', id=13, color=[255, 51, 51], type='', swap=''),
37+
14:
38+
dict(name='F3_KNU2_A', id=14, color=[255, 51, 51], type='', swap=''),
39+
15:
40+
dict(name='F3_KNU2_B', id=15, color=[255, 51, 51], type='', swap=''),
41+
16: dict(name='F3_KNU1_A', id=16, color=[0, 255, 0], type='', swap=''),
42+
17: dict(name='F3_KNU1_B', id=17, color=[0, 255, 0], type='', swap=''),
43+
18: dict(name='F4_KNU3_A', id=18, color=[0, 255, 0], type='', swap=''),
44+
19: dict(name='F4_KNU3_B', id=19, color=[0, 255, 0], type='', swap=''),
45+
20:
46+
dict(name='F4_KNU2_A', id=20, color=[255, 255, 255], type='', swap=''),
47+
21:
48+
dict(name='F4_KNU2_B', id=21, color=[255, 128, 0], type='', swap=''),
49+
22:
50+
dict(name='F4_KNU1_A', id=22, color=[255, 128, 0], type='', swap=''),
51+
23:
52+
dict(name='F4_KNU1_B', id=23, color=[255, 128, 0], type='', swap=''),
53+
24:
54+
dict(name='TH_KNU3_A', id=24, color=[255, 128, 0], type='', swap=''),
55+
25:
56+
dict(name='TH_KNU3_B', id=25, color=[255, 153, 255], type='', swap=''),
57+
26:
58+
dict(name='TH_KNU2_A', id=26, color=[255, 153, 255], type='', swap=''),
59+
27:
60+
dict(name='TH_KNU2_B', id=27, color=[255, 153, 255], type='', swap=''),
61+
28:
62+
dict(name='TH_KNU1_A', id=28, color=[255, 153, 255], type='', swap=''),
63+
29:
64+
dict(name='TH_KNU1_B', id=29, color=[102, 178, 255], type='', swap=''),
65+
30:
66+
dict(name='PALM_1', id=30, color=[102, 178, 255], type='', swap=''),
67+
31:
68+
dict(name='PALM_2', id=31, color=[102, 178, 255], type='', swap=''),
69+
32:
70+
dict(name='PALM_3', id=32, color=[102, 178, 255], type='', swap=''),
71+
33: dict(name='PALM_4', id=33, color=[255, 51, 51], type='', swap=''),
72+
34: dict(name='PALM_5', id=34, color=[255, 51, 51], type='', swap=''),
73+
35: dict(name='PALM_6', id=35, color=[255, 51, 51], type='', swap=''),
74+
},
75+
skeleton_info={
76+
0: dict(link=('PALM_3', 'F1_KNU2_B'), id=0, color=[255, 128, 0]),
77+
1: dict(link=('F1_KNU2_B', 'F1_KNU3_A'), id=1, color=[255, 128, 0]),
78+
2: dict(link=('PALM_3', 'F2_KNU2_B'), id=2, color=[255, 128, 0]),
79+
3: dict(link=('F2_KNU2_B', 'F2_KNU3_A'), id=3, color=[255, 128, 0]),
80+
4: dict(link=('PALM_3', 'F3_KNU2_B'), id=4, color=[255, 153, 255]),
81+
5: dict(link=('F3_KNU2_B', 'F3_KNU3_A'), id=5, color=[255, 153, 255]),
82+
6: dict(link=('PALM_3', 'F4_KNU2_B'), id=6, color=[255, 153, 255]),
83+
7: dict(link=('F4_KNU2_B', 'F4_KNU3_A'), id=7, color=[255, 153, 255]),
84+
8: dict(link=('PALM_3', 'TH_KNU2_B'), id=8, color=[102, 178, 255]),
85+
9: dict(link=('TH_KNU2_B', 'TH_KNU3_B'), id=9, color=[102, 178, 255]),
86+
10:
87+
dict(link=('TH_KNU3_B', 'TH_KNU3_A'), id=10, color=[102, 178, 255]),
88+
11: dict(link=('PALM_3', 'PALM_1'), id=11, color=[102, 178, 255]),
89+
12: dict(link=('PALM_3', 'PALM_2'), id=12, color=[255, 51, 51]),
90+
},
91+
joint_weights=[1.] * 36,
92+
sigmas=[])
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,177 @@
1+
_base_ = [
2+
'../../../../_base_/default_runtime.py',
3+
'../../../../_base_/datasets/nyu.py'
4+
]
5+
checkpoint_config = dict(interval=1)
6+
# TODO: metric
7+
evaluation = dict(
8+
interval=1,
9+
metric=['MRRPE', 'MPJPE', 'Handedness_acc'],
10+
save_best='MPJPE_all')
11+
12+
optimizer = dict(
13+
type='Adam',
14+
lr=2e-4,
15+
)
16+
optimizer_config = dict(grad_clip=None)
17+
# learning policy
18+
lr_config = dict(policy='step', step=[15, 17])
19+
total_epochs = 20
20+
log_config = dict(
21+
interval=20,
22+
hooks=[
23+
dict(type='TextLoggerHook'),
24+
# dict(type='TensorboardLoggerHook')
25+
])
26+
27+
load_from = '/root/mmpose/data/ckpt/new_res50.pth'
28+
used_keypoints_index = [0, 3, 6, 9, 12, 15, 18, 21, 24, 25, 27, 30, 31, 32]
29+
30+
channel_cfg = dict(
31+
num_output_channels=14,
32+
dataset_joints=36,
33+
dataset_channel=used_keypoints_index,
34+
inference_channel=used_keypoints_index)
35+
36+
# model settings
37+
model = dict(
38+
type='Depthhand3D', # pretrained=None
39+
backbone=dict(
40+
type='AWRResNet',
41+
depth=50,
42+
frozen_stages=-1,
43+
zero_init_residual=False,
44+
in_channels=1),
45+
keypoint_head=dict(
46+
type='AdaptiveWeightingRegression3DHead',
47+
offset_head_cfg=dict(
48+
in_channels=256,
49+
out_channels_vector=42,
50+
out_channels_scalar=14,
51+
heatmap_kernel_size=1.0,
52+
),
53+
deconv_head_cfg=dict(
54+
in_channels=2048,
55+
out_channels=256,
56+
depth_size=64,
57+
num_deconv_layers=3,
58+
num_deconv_filters=(256, 256, 256),
59+
num_deconv_kernels=(4, 4, 4),
60+
extra=dict(final_conv_kernel=0, )),
61+
loss_offset=dict(type='AWRSmoothL1Loss', use_target_weight=False),
62+
loss_keypoint=dict(type='AWRSmoothL1Loss', use_target_weight=True),
63+
),
64+
train_cfg=dict(use_img_for_head=True),
65+
test_cfg=dict(use_img_for_head=True, flip_test=False))
66+
67+
data_cfg = dict(
68+
image_size=[128, 128],
69+
heatmap_size=[64, 64, 56],
70+
cube_size=[300, 300, 300],
71+
heatmap_size_root=64,
72+
num_output_channels=channel_cfg['num_output_channels'],
73+
num_joints=channel_cfg['dataset_joints'],
74+
dataset_channel=channel_cfg['dataset_channel'],
75+
inference_channel=channel_cfg['inference_channel'])
76+
77+
train_pipeline = [
78+
dict(type='LoadImageFromFile', color_type='unchanged'),
79+
dict(type='TopDownGetBboxCenterScale', padding=1.0),
80+
dict(type='TopDownAffine'),
81+
dict(type='DepthToTensor'),
82+
dict(
83+
type='MultitaskGatherTarget',
84+
pipeline_list=[
85+
[
86+
dict(
87+
type='TopDownGenerateTargetRegression',
88+
use_zero_mean=True,
89+
joint_indices=used_keypoints_index,
90+
is_3d=True,
91+
normalize_depth=True,
92+
),
93+
dict(
94+
type='HandGenerateJointToOffset',
95+
heatmap_kernel_size=1.0,
96+
)
97+
],
98+
[
99+
dict(
100+
type='TopDownGenerateTargetRegression',
101+
use_zero_mean=True,
102+
joint_indices=used_keypoints_index,
103+
is_3d=True,
104+
normalize_depth=True,
105+
)
106+
],
107+
],
108+
pipeline_indices=[0, 1],
109+
),
110+
dict(
111+
type='Collect',
112+
keys=['img', 'target', 'target_weight'],
113+
meta_keys=[
114+
'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale',
115+
'rotation', 'flip_pairs', 'cube_size', 'center_depth', 'focal',
116+
'princpt', 'image_size', 'joints_cam', 'dataset_channel',
117+
'joints_uvd'
118+
]),
119+
]
120+
121+
val_pipeline = [
122+
dict(type='LoadImageFromFile', color_type='unchanged'),
123+
dict(type='TopDownGetBboxCenterScale', padding=1.0),
124+
dict(type='TopDownAffine'),
125+
dict(type='DepthToTensor'),
126+
dict(
127+
type='Collect',
128+
keys=['img'],
129+
meta_keys=[
130+
'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale',
131+
'rotation', 'flip_pairs', 'cube_size', 'center_depth', 'focal',
132+
'princpt', 'image_size', 'joints_cam', 'dataset_channel',
133+
'joints_uvd'
134+
])
135+
]
136+
137+
test_pipeline = val_pipeline
138+
139+
data_root = 'data/nyu'
140+
data = dict(
141+
samples_per_gpu=4,
142+
workers_per_gpu=0,
143+
shuffle=False,
144+
train=dict(
145+
type='NYUHandDataset',
146+
ann_file=f'{data_root}/annotations/nyu_test_data.json',
147+
camera_file=f'{data_root}/annotations/nyu_test_camera.json',
148+
joint_file=f'{data_root}/annotations/nyu_test_joint_3d.json',
149+
img_prefix=f'{data_root}/images/test/',
150+
data_cfg=data_cfg,
151+
use_refined_center=False,
152+
align_uvd_xyz_direction=True,
153+
pipeline=train_pipeline,
154+
dataset_info={{_base_.dataset_info}}),
155+
val=dict(
156+
type='NYUHandDataset',
157+
ann_file=f'{data_root}/annotations/nyu_test_data.json',
158+
camera_file=f'{data_root}/annotations/nyu_test_camera.json',
159+
joint_file=f'{data_root}/annotations/nyu_test_joint_3d.json',
160+
img_prefix=f'{data_root}/images/test/',
161+
data_cfg=data_cfg,
162+
use_refined_center=False,
163+
align_uvd_xyz_direction=True,
164+
pipeline=val_pipeline,
165+
dataset_info={{_base_.dataset_info}}),
166+
test=dict(
167+
type='NYUHandDataset',
168+
ann_file=f'{data_root}/annotations/nyu_test_data.json',
169+
camera_file=f'{data_root}/annotations/nyu_test_camera.json',
170+
joint_file=f'{data_root}/annotations/nyu_test_joint_3d.json',
171+
img_prefix=f'{data_root}/images/test/',
172+
data_cfg=data_cfg,
173+
use_refined_center=False,
174+
align_uvd_xyz_direction=True,
175+
pipeline=test_pipeline,
176+
dataset_info={{_base_.dataset_info}}),
177+
)

Diff for: mmpose/core/evaluation/top_down_eval.py

+35
Original file line numberDiff line numberDiff line change
@@ -655,6 +655,41 @@ def keypoints_from_heatmaps3d(heatmaps, center, scale):
655655
return preds, maxvals
656656

657657

658+
def keypoints_from_joint_uvd(joint_uvd, center, scale, image_size):
659+
"""Get final keypoint predictions from 3d heatmaps and transform them back
660+
to the image.
661+
662+
Note:
663+
- batch size: N
664+
- num keypoints: K
665+
- heatmap depth size: D
666+
- heatmap height: H
667+
- heatmap width: W
668+
669+
Args:
670+
heatmaps (np.ndarray[N, K, D, H, W]): model predicted heatmaps.
671+
center (np.ndarray[N, 2]): Center of the bounding box (x, y).
672+
scale (np.ndarray[N, 2]): Scale of the bounding box
673+
wrt height/width.
674+
675+
Returns:
676+
tuple: A tuple containing keypoint predictions and scores.
677+
678+
- preds (np.ndarray[N, K, 3]): Predicted 3d keypoint location \
679+
in images.
680+
- maxvals (np.ndarray[N, K, 1]): Scores (confidence) of the keypoints.
681+
"""
682+
N, K, D = joint_uvd.shape
683+
preds = joint_uvd
684+
maxvals = np.ones((N, K, 1), dtype=np.float32)
685+
# Transform back to the image
686+
for i in range(N):
687+
preds[i, :, :2] = transform_preds(
688+
(preds[i, :, :2] + 1) * image_size[i] / 2, center[i], scale[i],
689+
[image_size[i, 1], image_size[i, 0]])
690+
return preds, maxvals
691+
692+
658693
def multilabel_classification_accuracy(pred, gt, mask, thr=0.5):
659694
"""Get multi-label classification accuracy.
660695

Diff for: mmpose/datasets/datasets/base/__init__.py

+9-3
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,18 @@
66
from .kpt_2d_sview_rgb_vid_top_down_dataset import \
77
Kpt2dSviewRgbVidTopDownDataset
88
from .kpt_3d_mview_rgb_img_direct_dataset import Kpt3dMviewRgbImgDirectDataset
9+
from .kpt_3d_sview_depth_img_top_down_dataset import \
10+
Kpt3dSviewDepthImgTopDownDataset
911
from .kpt_3d_sview_kpt_2d_dataset import Kpt3dSviewKpt2dDataset
1012
from .kpt_3d_sview_rgb_img_top_down_dataset import \
1113
Kpt3dSviewRgbImgTopDownDataset
1214

1315
__all__ = [
14-
'Kpt3dMviewRgbImgDirectDataset', 'Kpt2dSviewRgbImgTopDownDataset',
15-
'Kpt3dSviewRgbImgTopDownDataset', 'Kpt2dSviewRgbImgBottomUpDataset',
16-
'Kpt3dSviewKpt2dDataset', 'Kpt2dSviewRgbVidTopDownDataset'
16+
'Kpt3dMviewRgbImgDirectDataset',
17+
'Kpt2dSviewRgbImgTopDownDataset',
18+
'Kpt3dSviewRgbImgTopDownDataset',
19+
'Kpt2dSviewRgbImgBottomUpDataset',
20+
'Kpt3dSviewKpt2dDataset',
21+
'Kpt2dSviewRgbVidTopDownDataset',
22+
'Kpt3dSviewDepthImgTopDownDataset',
1723
]

0 commit comments

Comments
 (0)