[Data] Refine the skeleton extraction script #150

Merged 6 commits on Mar 20, 2023
4 changes: 2 additions & 2 deletions README.md
@@ -16,13 +16,13 @@ This repo is the official implementation of [PoseConv3D](https://arxiv.org/abs/2104.13586)
 </figure>
 </div>
 
-## News
+## Change Log
 
+- Improve the skeleton extraction script ([PR](https://github.com/kennymckormick/pyskl/pull/150)). It now supports non-distributed skeleton extraction and K400-style compression (**2023-03-20**).
 - Support PyTorch 2.0: when `--compile` is set for the training/testing scripts and `torch.__version__ >= 'v2.0.0'` is detected, `torch.compile` is used to compile the model before training/testing. This is an experimental feature with no performance guarantee (**2023-03-16**).
 - Provide a real-time gesture recognition demo based on skeleton-based action recognition with ST-GCN++; check [Demo](/demo/demo.md) for details and instructions (**2023-02-10**).
 - Provide [scripts](/examples/inference_speed.ipynb) to estimate the inference speed of each model (**2022-12-30**).
 - Support [RGBPoseConv3D](https://arxiv.org/abs/2104.13586), a two-stream 3D-CNN for action recognition based on RGB & human skeletons. Follow the [guide](/configs/rgbpose_conv3d/README.md) to train and test RGBPoseConv3D on NTURGB+D (**2022-12-29**).
 - Provide a script ([ntu_preproc.py](/tools/data/ntu_preproc.py)) to generate PYSKL-style annotation files from the official NTURGB+D skeleton files (**2022-12-20**).
 
 ## Supported Algorithms
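The PyTorch 2.0 entry above gates compilation on a CLI flag plus a runtime version check. A minimal sketch of that gating, assuming a hypothetical helper name and a naive string comparison in place of whatever version parsing the repo actually performs:

```python
import torch


def maybe_compile(model, use_compile: bool):
    # Hypothetical helper, not the repo's actual function: compile the model
    # only when the flag is set and a PyTorch 2.x runtime is detected.
    if use_compile and torch.__version__ >= '2.0.0':
        model = torch.compile(model)  # experimental; no performance guarantee
    return model
```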
6 changes: 3 additions & 3 deletions pyskl/datasets/pose_dataset.py
@@ -46,7 +46,7 @@ def __init__(self,
                  pipeline,
                  split=None,
                  valid_ratio=None,
-                 box_thr=0.5,
+                 box_thr=None,
                  class_prob=None,
                  memcached=False,
                  mc_cfg=('localhost', 22077),
@@ -65,13 +65,13 @@

         # Thresholding Training Examples
         self.valid_ratio = valid_ratio
-        if self.valid_ratio is not None:
-            assert isinstance(self.valid_ratio, float)
+        if self.valid_ratio is not None and isinstance(self.valid_ratio, float) and self.valid_ratio > 0:
             self.video_infos = [
                 x for x in self.video_infos
                 if x['valid'][self.box_thr] / x['total_frames'] >= valid_ratio
             ]
             for item in self.video_infos:
+                assert 'box_score' in item, 'if valid_ratio is a positive number, item should have field `box_score`'
                 anno_inds = (item['box_score'] >= self.box_thr)
                 item['anno_inds'] = anno_inds
         for item in self.video_infos:
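The change above makes the `valid_ratio` filter opt-in (it now runs only for a positive float) and asserts that filtered items carry a `box_score` field before building the per-box `anno_inds` mask. A toy sketch of that filtering with invented annotations (field names follow the diff; the numbers are made up):

```python
import numpy as np

box_thr, valid_ratio = 0.5, 0.6
video_infos = [
    dict(frame_dir='clip_a', total_frames=50, valid={0.5: 40},
         box_score=np.array([0.9, 0.3, 0.7])),
    dict(frame_dir='clip_b', total_frames=50, valid={0.5: 10},
         box_score=np.array([0.4, 0.2, 0.1])),
]

if valid_ratio is not None and isinstance(valid_ratio, float) and valid_ratio > 0:
    # Keep clips whose fraction of valid frames (at this box threshold) is high enough.
    video_infos = [x for x in video_infos
                   if x['valid'][box_thr] / x['total_frames'] >= valid_ratio]
    for item in video_infos:
        assert 'box_score' in item
        # Per-box mask: which detections clear the threshold.
        item['anno_inds'] = item['box_score'] >= box_thr

print([x['frame_dir'] for x in video_infos])  # ['clip_a'] (40/50 >= 0.6)
```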
99 changes: 62 additions & 37 deletions tools/data/custom_2d_skeleton.py
@@ -1,5 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import argparse
+import copy as cp
 import decord
 import mmcv
 import numpy as np
@@ -56,19 +57,36 @@ def detection_inference(model, frames):
     return results
 
 
-def pose_inference(model, frames, det_results):
+def pose_inference(anno_in, model, frames, det_results, compress=False):
+    anno = cp.deepcopy(anno_in)
     assert len(frames) == len(det_results)
     total_frames = len(frames)
     num_person = max([len(x) for x in det_results])
-    kp = np.zeros((num_person, total_frames, 17, 3), dtype=np.float32)
-
-    for i, (f, d) in enumerate(zip(frames, det_results)):
-        # Align input format
-        d = [dict(bbox=x) for x in list(d)]
-        pose = inference_top_down_pose_model(model, f, d, format='xyxy')[0]
-        for j, item in enumerate(pose):
-            kp[j, i] = item['keypoints']
-    return kp
+    anno['total_frames'] = total_frames
+    anno['num_person_raw'] = num_person
+
+    if compress:
+        kp, frame_inds = [], []
+        for i, (f, d) in enumerate(zip(frames, det_results)):
+            # Align input format
+            d = [dict(bbox=x) for x in list(d)]
+            pose = inference_top_down_pose_model(model, f, d, format='xyxy')[0]
+            for j, item in enumerate(pose):
+                kp.append(item['keypoints'])
+                frame_inds.append(i)
+        anno['keypoint'] = np.stack(kp).astype(np.float16)
+        anno['frame_inds'] = np.array(frame_inds, dtype=np.int16)
+    else:
+        kp = np.zeros((num_person, total_frames, 17, 3), dtype=np.float32)
+        for i, (f, d) in enumerate(zip(frames, det_results)):
+            # Align input format
+            d = [dict(bbox=x) for x in list(d)]
+            pose = inference_top_down_pose_model(model, f, d, format='xyxy')[0]
+            for j, item in enumerate(pose):
+                kp[j, i] = item['keypoints']
+        anno['keypoint'] = kp[..., :2].astype(np.float16)
+        anno['keypoint_score'] = kp[..., 2].astype(np.float16)
+    return anno
 
 
 def parse_args():
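The refactored `pose_inference` now fills the annotation dict itself and supports two storage layouts: a dense `(num_person, total_frames, 17, 3)` array split into `keypoint` and `keypoint_score`, or a K400-style compressed form that stacks only detected person-frames and records a `frame_inds` mapping. A toy sketch contrasting the two layouts (the detection counts are invented):

```python
import numpy as np

total_frames, num_person = 4, 2

# Dense layout: a fixed grid, zero-padded where a person is missing in a frame.
kp = np.zeros((num_person, total_frames, 17, 3), dtype=np.float32)
dense = dict(
    keypoint=kp[..., :2].astype(np.float16),       # (2, 4, 17, 2) coordinates
    keypoint_score=kp[..., 2].astype(np.float16),  # (2, 4, 17) confidences
)

# Compressed layout: one (17, 3) entry per detected person per frame,
# with frame_inds mapping each entry back to its source frame.
detections_per_frame = [1, 2, 0, 1]  # e.g. nobody detected in frame 2
kp_list, frame_inds = [], []
for i, n in enumerate(detections_per_frame):
    for _ in range(n):
        kp_list.append(np.zeros((17, 3), dtype=np.float32))
        frame_inds.append(i)
compressed = dict(
    keypoint=np.stack(kp_list).astype(np.float16),    # (4, 17, 3); scores kept in channel 2
    frame_inds=np.array(frame_inds, dtype=np.int16),  # [0, 1, 1, 3]
)
```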
@@ -95,6 +113,9 @@ def parse_args():
     parser.add_argument('--out', type=str, help='output pickle name')
     parser.add_argument('--tmpdir', type=str, default='tmp')
     parser.add_argument('--local_rank', type=int, default=0)
+    # * When --non-dist is set, only 1 GPU will be used
+    parser.add_argument('--non-dist', action='store_true', help='whether to use non-distributed skeleton extraction')
+    parser.add_argument('--compress', action='store_true', help='whether to do K400-style compression')
     args = parser.parse_args()
     if 'LOCAL_RANK' not in os.environ:
         os.environ['LOCAL_RANK'] = str(args.local_rank)
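With the two new flags, a single-GPU extraction that also compresses its output would be launched as `python tools/data/custom_2d_skeleton.py ... --out out.pkl --non-dist --compress`; the elided arguments (detector/pose configs, checkpoints, and the video list) are defined outside this excerpt.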
@@ -116,18 +137,22 @@ def main():
     else:
         annos = [dict(frame_dir=osp.basename(x[0]).split('.')[0], filename=x[0], label=int(x[1])) for x in lines]
 
-    init_dist('pytorch', backend='nccl')
-    rank, world_size = get_dist_info()
-
-    if rank == 0:
+    if args.non_dist:
+        my_part = annos
         os.makedirs(args.tmpdir, exist_ok=True)
-    dist.barrier()
-    my_part = annos[rank::world_size]
+    else:
+        init_dist('pytorch', backend='nccl')
+        rank, world_size = get_dist_info()
+        if rank == 0:
+            os.makedirs(args.tmpdir, exist_ok=True)
+        dist.barrier()
+        my_part = annos[rank::world_size]
 
     det_model = init_detector(args.det_config, args.det_ckpt, 'cuda')
     assert det_model.CLASSES[0] == 'person', 'A detector trained on COCO is required'
     pose_model = init_pose_model(args.pose_config, args.pose_ckpt, 'cuda')
 
+    results = []
     for anno in tqdm(my_part):
         frames = extract_frame(anno['filename'])
         det_results = detection_inference(det_model, frames)
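In the distributed branch, each rank takes an interleaved shard via `annos[rank::world_size]`, so rank `i` processes items `i, i + world_size, i + 2 * world_size, ...`. A tiny illustration of that slicing with stand-in data:

```python
annos = list(range(10))  # stand-ins for the annotation dicts
world_size = 4
parts = [annos[rank::world_size] for rank in range(world_size)]
# parts == [[0, 4, 8], [1, 5, 9], [2, 6], [3, 7]]
```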
@@ -142,30 +167,30 @@
             res = res[box_areas >= args.det_area_thr]
             det_results[i] = res
 
-        pose_results = pose_inference(pose_model, frames, det_results)
         shape = frames[0].shape[:2]
-        anno['img_shape'] = anno['original_shape'] = shape
-        anno['total_frames'] = len(frames)
-        anno['num_person_raw'] = pose_results.shape[0]
-        anno['keypoint'] = pose_results[..., :2].astype(np.float16)
-        anno['keypoint_score'] = pose_results[..., 2].astype(np.float16)
+        anno['img_shape'] = shape
+        anno = pose_inference(anno, pose_model, frames, det_results, compress=args.compress)
         anno.pop('filename')
+        results.append(anno)
 
-    mmcv.dump(my_part, osp.join(args.tmpdir, f'part_{rank}.pkl'))
-    dist.barrier()
-
-    if rank == 0:
-        parts = [mmcv.load(osp.join(args.tmpdir, f'part_{i}.pkl')) for i in range(world_size)]
-        rem = len(annos) % world_size
-        if rem:
-            for i in range(rem, world_size):
-                parts[i].append(None)
-
-        ordered_results = []
-        for res in zip(*parts):
-            ordered_results.extend(list(res))
-        ordered_results = ordered_results[:len(annos)]
-        mmcv.dump(ordered_results, args.out)
+    if args.non_dist:
+        mmcv.dump(results, args.out)
+    else:
+        mmcv.dump(results, osp.join(args.tmpdir, f'part_{rank}.pkl'))
+        dist.barrier()
+
+        if rank == 0:
+            parts = [mmcv.load(osp.join(args.tmpdir, f'part_{i}.pkl')) for i in range(world_size)]
+            rem = len(annos) % world_size
+            if rem:
+                for i in range(rem, world_size):
+                    parts[i].append(None)
+
+            ordered_results = []
+            for res in zip(*parts):
+                ordered_results.extend(list(res))
+            ordered_results = ordered_results[:len(annos)]
+            mmcv.dump(ordered_results, args.out)
 
 
 if __name__ == '__main__':
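Because the shards are interleaved, rank 0 reassembles them by padding the shorter shards with `None`, round-robin zipping one item from each rank, and truncating back to the original length. A self-contained sketch of that merge, reusing the stand-in data from above:

```python
annos = list(range(10))
world_size = 4
parts = [annos[rank::world_size] for rank in range(world_size)]

# Pad the ranks that received one item fewer, so zip() keeps the last round.
rem = len(annos) % world_size
if rem:
    for i in range(rem, world_size):
        parts[i].append(None)

ordered = []
for res in zip(*parts):  # one item from each rank per round
    ordered.extend(res)
ordered = ordered[:len(annos)]  # drop the None padding
assert ordered == annos  # original order restored
```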