Commit 2813e89

lyviva and hhaAndroid authored
[Feature] Implement fast version of YOLOX (#518)
* Implement fast version of YOLOX
* config change
* Update yolox_head.py
* Update mmyolo/models/data_preprocessors/data_preprocessor.py
  Co-authored-by: Haian Huang(深度眸) <[email protected]>
* Update mmyolo/models/data_preprocessors/data_preprocessor.py
  Co-authored-by: Haian Huang(深度眸) <[email protected]>
* add test and modify faults
* fix lint
* fix lint
* modify metafile and README
* modify metafile and readme
* fix
* fix
* fix
* fix
* fix
* fix test

---------

Co-authored-by: Haian Huang(深度眸) <[email protected]>
1 parent 031e745 commit 2813e89

14 files changed: +130 -35 lines changed

configs/yolox/README.md

+4 -4
@@ -19,10 +19,10 @@ YOLOX-l model structure
 
 ## Results and Models
 
-| Backbone | size | Mem (GB) | box AP | Config | Download |
-| :--------: | :--: | :------: | :----: | :---------------------------------------------------------------------------------------------------: | :----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
-| YOLOX-tiny | 416 | 2.8 | 32.7 | [config](https://github.com/open-mmlab/mmyolo/tree/master/configs/yolox/yolox_tiny_8xb8-300e_coco.py) | [model](https://download.openmmlab.com/mmyolo/v0/yolox/yolox_tiny_8xb8-300e_coco/yolox_tiny_8xb8-300e_coco_20220919_090908-0e40a6fc.pth) \| [log](https://download.openmmlab.com/mmyolo/v0/yolox/yolox_tiny_8xb8-300e_coco/yolox_tiny_8xb8-300e_coco_20220919_090908.log.json) |
-| YOLOX-s | 640 | 5.6 | 40.8 | [config](https://github.com/open-mmlab/mmyolo/tree/master/configs/yolox/yolox_s_8xb8-300e_coco.py) | [model](https://download.openmmlab.com/mmyolo/v0/yolox/yolox_s_8xb8-300e_coco/yolox_s_8xb8-300e_coco_20220917_030738-d7e60cb2.pth) \| [log](https://download.openmmlab.com/mmyolo/v0/yolox/yolox_s_8xb8-300e_coco/yolox_s_8xb8-300e_coco_20220917_030738.log.json) |
+| Backbone | size | Mem (GB) | box AP | Config | Download |
+| :--------: | :--: | :------: | :----: | :--------------------------------------------------------------------------------------------------------: | :----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
+| YOLOX-tiny | 416 | 2.8 | 32.7 | [config](https://github.com/open-mmlab/mmyolo/tree/master/configs/yolox/yolox_tiny_fast_8xb8-300e_coco.py) | [model](https://download.openmmlab.com/mmyolo/v0/yolox/yolox_tiny_8xb8-300e_coco/yolox_tiny_8xb8-300e_coco_20220919_090908-0e40a6fc.pth) \| [log](https://download.openmmlab.com/mmyolo/v0/yolox/yolox_tiny_8xb8-300e_coco/yolox_tiny_8xb8-300e_coco_20220919_090908.log.json) |
+| YOLOX-s | 640 | 5.6 | 40.8 | [config](https://github.com/open-mmlab/mmyolo/tree/master/configs/yolox/yolox_s_fast_8xb8-300e_coco.py) | [model](https://download.openmmlab.com/mmyolo/v0/yolox/yolox_s_8xb8-300e_coco/yolox_s_8xb8-300e_coco_20220917_030738-d7e60cb2.pth) \| [log](https://download.openmmlab.com/mmyolo/v0/yolox/yolox_s_8xb8-300e_coco/yolox_s_8xb8-300e_coco_20220917_030738.log.json) |
 
 **Note**:
 

configs/yolox/metafile.yml

+4 -4
@@ -20,9 +20,9 @@ Collections:
 
 
 Models:
-  - Name: yolox_tiny_8xb8-300e_coco
+  - Name: yolox_tiny_fast_8xb8-300e_coco
     In Collection: YOLOX
-    Config: configs/yolox/yolox_tiny_8xb8-300e_coco.py
+    Config: configs/yolox/yolox_tiny_fast_8xb8-300e_coco.py
     Metadata:
       Training Memory (GB): 2.8
       Epochs: 300
@@ -32,9 +32,9 @@ Models:
         Metrics:
           box AP: 32.7
     Weights: https://download.openmmlab.com/mmyolo/v0/yolox/yolox_tiny_8xb8-300e_coco/yolox_tiny_8xb8-300e_coco_20220919_090908-0e40a6fc.pth
-  - Name: yolox_s_8xb8-300e_coco
+  - Name: yolox_s_fast_8xb8-300e_coco
     In Collection: YOLOX
-    Config: configs/yolox/yolox_s_8xb8-300e_coco.py
+    Config: configs/yolox/yolox_s_fast_8xb8-300e_coco.py
     Metadata:
       Training Memory (GB): 5.6
       Epochs: 300

configs/yolox/yolox_l_8xb8-300e_coco.py renamed to configs/yolox/yolox_l_fast_8xb8-300e_coco.py

+1 -1
@@ -1,4 +1,4 @@
-_base_ = './yolox_s_8xb8-300e_coco.py'
+_base_ = './yolox_s_fast_8xb8-300e_coco.py'
 
 deepen_factor = 1.0
 widen_factor = 1.0

configs/yolox/yolox_m_8xb8-300e_coco.py renamed to configs/yolox/yolox_m_fast_8xb8-300e_coco.py

+1 -1
@@ -1,4 +1,4 @@
-_base_ = './yolox_s_8xb8-300e_coco.py'
+_base_ = './yolox_s_fast_8xb8-300e_coco.py'
 
 deepen_factor = 0.67
 widen_factor = 0.75

configs/yolox/yolox_nano_8xb8-300e_coco.py renamed to configs/yolox/yolox_nano_fast_8xb8-300e_coco.py

+1 -1
@@ -1,4 +1,4 @@
-_base_ = './yolox_tiny_8xb8-300e_coco.py'
+_base_ = './yolox_tiny_fast_8xb8-300e_coco.py'
 
 deepen_factor = 0.33
 widen_factor = 0.25

configs/yolox/yolox_s_8xb8-300e_coco.py renamed to configs/yolox/yolox_s_fast_8xb8-300e_coco.py

+3 -2
@@ -29,11 +29,11 @@
     # TODO: Waiting for mmengine support
     use_syncbn=False,
     data_preprocessor=dict(
-        type='mmdet.DetDataPreprocessor',
+        type='YOLOv5DetDataPreprocessor',
         pad_size_divisor=32,
         batch_augments=[
             dict(
-                type='mmdet.BatchSyncRandomResize',
+                type='YOLOXBatchSyncRandomResize',
                 random_size_range=(480, 800),
                 size_divisor=32,
                 interval=10)
@@ -157,6 +157,7 @@
     num_workers=train_num_workers,
     persistent_workers=True,
     pin_memory=True,
+    collate_fn=dict(type='yolov5_collate'),
     sampler=dict(type='DefaultSampler', shuffle=True),
     dataset=dict(
         type=dataset_type,
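
Read together, the three changed lines in this config swap the preprocessor, swap the batch augment, and add a collate function. The fragment below is only a sketch of how those pieces sit after the commit. It shows just the keys visible in the hunks. In the real config `data_preprocessor` is nested inside the model dict; the standalone variable names and the `train_num_workers` value are illustrative assumptions, not taken from the diff.

```python
# Sketch of the data path in yolox_s_fast_8xb8-300e_coco.py after this commit.
# Only keys that appear in the diff are shown; everything else is elided.
train_num_workers = 8  # placeholder value, an assumption for this sketch

# In the real config this dict is nested inside the model definition.
data_preprocessor = dict(
    type='YOLOv5DetDataPreprocessor',  # was 'mmdet.DetDataPreprocessor'
    pad_size_divisor=32,
    batch_augments=[
        dict(
            type='YOLOXBatchSyncRandomResize',  # was 'mmdet.BatchSyncRandomResize'
            random_size_range=(480, 800),
            size_divisor=32,
            interval=10)
    ])

train_dataloader = dict(
    num_workers=train_num_workers,
    persistent_workers=True,
    pin_memory=True,
    collate_fn=dict(type='yolov5_collate'),  # the line added by this commit
    sampler=dict(type='DefaultSampler', shuffle=True))
```

The collate function and the preprocessor change together because the new resize class asserts that `data_samples` is a dict and reads `data_samples['bboxes_labels']`, as the data_preprocessor.py hunk in this commit shows.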

configs/yolox/yolox_tiny_8xb8-300e_coco.py renamed to configs/yolox/yolox_tiny_fast_8xb8-300e_coco.py

+1 -1
@@ -1,4 +1,4 @@
-_base_ = './yolox_s_8xb8-300e_coco.py'
+_base_ = './yolox_s_fast_8xb8-300e_coco.py'
 
 deepen_factor = 0.33
 widen_factor = 0.375

configs/yolox/yolox_x_8xb8-300e_coco.py renamed to configs/yolox/yolox_x_fast_8xb8-300e_coco.py

+1 -1
@@ -1,4 +1,4 @@
-_base_ = './yolox_s_8xb8-300e_coco.py'
+_base_ = './yolox_s_fast_8xb8-300e_coco.py'
 
 deepen_factor = 1.33
 widen_factor = 1.25

@@ -1,9 +1,10 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .data_preprocessor import (PPYOLOEBatchRandomResize,
                                 PPYOLOEDetDataPreprocessor,
-                                YOLOv5DetDataPreprocessor)
+                                YOLOv5DetDataPreprocessor,
+                                YOLOXBatchSyncRandomResize)
 
 __all__ = [
     'YOLOv5DetDataPreprocessor', 'PPYOLOEDetDataPreprocessor',
-    'PPYOLOEBatchRandomResize'
+    'PPYOLOEBatchRandomResize', 'YOLOXBatchSyncRandomResize'
 ]

mmyolo/models/data_preprocessors/data_preprocessor.py

+41
@@ -16,6 +16,47 @@
                  None]
 
 
+@MODELS.register_module()
+class YOLOXBatchSyncRandomResize(BatchSyncRandomResize):
+    """YOLOX batch random resize.
+
+    Args:
+        random_size_range (tuple): The multi-scale random range during
+            multi-scale training.
+        interval (int): The iter interval of change
+            image size. Defaults to 10.
+        size_divisor (int): Image size divisible factor.
+            Defaults to 32.
+    """
+
+    def forward(self, inputs: Tensor, data_samples: dict) -> Tensor and dict:
+        """resize a batch of images and bboxes to shape ``self._input_size``"""
+        h, w = inputs.shape[-2:]
+        inputs = inputs.float()
+        assert isinstance(data_samples, dict)
+
+        if self._input_size is None:
+            self._input_size = (h, w)
+        scale_y = self._input_size[0] / h
+        scale_x = self._input_size[1] / w
+        if scale_x != 1 or scale_y != 1:
+            inputs = F.interpolate(
+                inputs,
+                size=self._input_size,
+                mode='bilinear',
+                align_corners=False)
+
+            data_samples['bboxes_labels'][:, 2::2] *= scale_x
+            data_samples['bboxes_labels'][:, 3::2] *= scale_y
+
+        message_hub = MessageHub.get_current_instance()
+        if (message_hub.get_info('iter') + 1) % self._interval == 0:
+            self._input_size = self._get_random_size(
+                aspect_ratio=float(w / h), device=inputs.device)
+
+        return inputs, data_samples
+
+
 @MODELS.register_module()
 class YOLOv5DetDataPreprocessor(DetDataPreprocessor):
     """Rewrite collate_fn to get faster training speed.

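The arithmetic in `YOLOXBatchSyncRandomResize.forward` is easier to check with concrete numbers. The sketch below uses plain Python lists as a stand-in for tensors and is not the library code. The helper names `rescale_bboxes_labels` and `should_resample` are hypothetical, and the row layout `[batch_idx, label, x1, y1, x2, y2]` is an assumption inferred from the `2::2` / `3::2` slices here and from `gt_instances_preprocess` in yolox_head.py.

```python
from typing import List, Sequence, Tuple


def rescale_bboxes_labels(rows: Sequence[Sequence[float]],
                          src_hw: Tuple[int, int],
                          dst_hw: Tuple[int, int]) -> List[List[float]]:
    """Scale box coordinates the way the 2::2 and 3::2 slices do.

    Each row is assumed to be [batch_idx, label, x1, y1, x2, y2].
    """
    scale_y = dst_hw[0] / src_hw[0]
    scale_x = dst_hw[1] / src_hw[1]
    scaled = []
    for row in rows:
        new_row = list(row)
        for col in range(2, len(new_row), 2):  # columns 2 and 4: x1, x2
            new_row[col] *= scale_x
        for col in range(3, len(new_row), 2):  # columns 3 and 5: y1, y2
            new_row[col] *= scale_y
        scaled.append(new_row)
    return scaled


def should_resample(cur_iter: int, interval: int) -> bool:
    """Mirror the `(iter + 1) % interval == 0` check used to pick a new size."""
    return (cur_iter + 1) % interval == 0


rows = [[0.0, 2.0, 10.0, 20.0, 110.0, 220.0]]
scaled = rescale_bboxes_labels(rows, src_hw=(320, 640), dst_hw=(480, 800))
print(scaled)  # [[0.0, 2.0, 12.5, 30.0, 137.5, 330.0]]
print(should_resample(9, 10), should_resample(0, 10))  # True False
```

The batch index and label in columns 0 and 1 are left untouched, which is why both slices start at column 2 or later.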
mmyolo/models/dense_heads/yolox_head.py

+29 -1
@@ -265,7 +265,7 @@ def loss_by_feat(
             cls_scores: Sequence[Tensor],
             bbox_preds: Sequence[Tensor],
             objectnesses: Sequence[Tensor],
-            batch_gt_instances: Sequence[InstanceData],
+            batch_gt_instances: Tensor,
             batch_img_metas: Sequence[dict],
             batch_gt_instances_ignore: OptInstanceList = None) -> dict:
         """Calculate the loss based on the features extracted by the detection
@@ -297,6 +297,9 @@ def loss_by_feat(
         if batch_gt_instances_ignore is None:
             batch_gt_instances_ignore = [None] * num_imgs
 
+        batch_gt_instances = self.gt_instances_preprocess(
+            batch_gt_instances, len(batch_img_metas))
+
         featmap_sizes = [cls_score.shape[2:] for cls_score in cls_scores]
         mlvl_priors = self.prior_generator.grid_priors(
             featmap_sizes,
@@ -484,3 +487,28 @@ def _get_bbox_aux_target(self,
         bbox_aux_target[:,
                         2:] = torch.log(gt_cxcywh[:, 2:] / priors[:, 2:] + eps)
         return bbox_aux_target
+
+    @staticmethod
+    def gt_instances_preprocess(batch_gt_instances: Tensor,
+                                batch_size: int) -> List[InstanceData]:
+        """Split batch_gt_instances with batch size.
+
+        Args:
+            batch_gt_instances (Tensor): Ground truth
+                a 2D-Tensor for whole batch, shape [all_gt_bboxes, 6]
+            batch_size (int): Batch size.
+
+        Returns:
+            List: batch gt instances data, shape [batch_size, InstanceData]
+        """
+        # faster version
+        batch_instance_list = []
+        for i in range(batch_size):
+            batch_gt_instance_ = InstanceData()
+            single_batch_instance = \
+                batch_gt_instances[batch_gt_instances[:, 0] == i, :]
+            batch_gt_instance_.bboxes = single_batch_instance[:, 2:]
+            batch_gt_instance_.labels = single_batch_instance[:, 1]
+            batch_instance_list.append(batch_gt_instance_)
+
+        return batch_instance_list
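
To make the `[all_gt_bboxes, 6]` layout concrete, here is a minimal pure-Python sketch of the same split: column 0 selects the image, column 1 is the label, and columns 2 onward are the box. Plain lists and dicts stand in for `Tensor` and `InstanceData`. The helper name `split_by_image` and the sample values are hypothetical.

```python
from typing import Dict, List, Sequence


def split_by_image(rows: Sequence[Sequence[float]],
                   batch_size: int) -> List[Dict[str, list]]:
    """Group flat ground-truth rows by the batch index in column 0.

    Each row is assumed to be [batch_idx, label, x1, y1, x2, y2].
    """
    per_image = []
    for i in range(batch_size):
        selected = [row for row in rows if row[0] == i]
        per_image.append({
            'labels': [row[1] for row in selected],
            'bboxes': [list(row[2:]) for row in selected],
        })
    return per_image


rows = [
    [0, 2, 10.0, 20.0, 110.0, 220.0],
    [2, 5, 1.0, 2.0, 3.0, 4.0],
    [0, 7, 5.0, 5.0, 50.0, 50.0],
]
split = split_by_image(rows, batch_size=3)
print(split[0]['labels'])  # [2, 7]
print(split[1])            # {'labels': [], 'bboxes': []}
print(split[2]['bboxes'])  # [[1.0, 2.0, 3.0, 4.0]]
```

An image with no ground truth still gets an entry, only an empty one. That matches the updated head test, which passes `torch.empty((0, 6))` for the empty-ground-truth case.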

tests/test_models/test_data_preprocessor/test_data_preprocessor.py

+30 -1
@@ -6,7 +6,8 @@
 from mmengine import MessageHub
 
 from mmyolo.models import PPYOLOEBatchRandomResize, PPYOLOEDetDataPreprocessor
-from mmyolo.models.data_preprocessors import YOLOv5DetDataPreprocessor
+from mmyolo.models.data_preprocessors import (YOLOv5DetDataPreprocessor,
+                                              YOLOXBatchSyncRandomResize)
 from mmyolo.utils import register_all_modules
 
 register_all_modules()
@@ -125,3 +126,31 @@ def test_batch_random_resize(self):
         # data_samples must be list
         with self.assertRaises(AssertionError):
             processor(data, training=True)
+
+
+class TestYOLOXDetDataPreprocessor(TestCase):
+
+    def test_batch_sync_random_size(self):
+        processor = YOLOXBatchSyncRandomResize(
+            random_size_range=(480, 800), size_divisor=32, interval=1)
+        self.assertTrue(isinstance(processor, YOLOXBatchSyncRandomResize))
+        message_hub = MessageHub.get_instance(
+            'test_yolox_batch_sync_random_resize')
+        message_hub.update_info('iter', 0)
+
+        # test training
+        inputs = torch.randint(0, 256, (4, 3, 10, 11))
+        data_samples = {'bboxes_labels': torch.randint(0, 11, (18, 6)).float()}
+
+        inputs, data_samples = processor(inputs, data_samples)
+
+        self.assertIn('bboxes_labels', data_samples)
+        self.assertIsInstance(data_samples['bboxes_labels'], torch.Tensor)
+        self.assertIsInstance(inputs, torch.Tensor)
+
+        inputs = torch.randint(0, 256, (4, 3, 10, 11))
+        data_samples = DetDataSample()
+
+        # data_samples must be dict
+        with self.assertRaises(AssertionError):
+            processor(inputs, data_samples)

tests/test_models/test_dense_heads/test_yolox_head.py

+8 -12
@@ -4,7 +4,6 @@
 import torch
 from mmengine.config import Config
 from mmengine.model import bias_init_with_prob
-from mmengine.structures import InstanceData
 from mmengine.testing import assert_allclose
 
 from mmyolo.models.dense_heads import YOLOXHead
@@ -98,11 +97,10 @@ def test_loss_by_feat(self):
 
         # Test that empty ground truth encourages the network to predict
         # background
-        gt_instances = InstanceData(
-            bboxes=torch.empty((0, 4)), labels=torch.LongTensor([]))
+        gt_instances = torch.empty((0, 6))
 
         empty_gt_losses = head.loss_by_feat(cls_scores, bbox_preds,
-                                            objectnesses, [gt_instances],
+                                            objectnesses, gt_instances,
                                             img_metas)
         # When there is no truth, the cls loss should be nonzero but there
         # should be no box loss.
@@ -122,12 +120,11 @@ def test_loss_by_feat(self):
         # for random inputs
         head = YOLOXHead(head_module=self.head_module, train_cfg=train_cfg)
         head.use_bbox_aux = True
-        gt_instances = InstanceData(
-            bboxes=torch.Tensor([[23.6667, 23.8757, 238.6326, 151.8874]]),
-            labels=torch.LongTensor([2]))
+        gt_instances = torch.Tensor(
+            [[0, 2, 23.6667, 23.8757, 238.6326, 151.8874]])
 
         one_gt_losses = head.loss_by_feat(cls_scores, bbox_preds, objectnesses,
-                                          [gt_instances], img_metas)
+                                          gt_instances, img_metas)
         onegt_cls_loss = one_gt_losses['loss_cls'].sum()
         onegt_box_loss = one_gt_losses['loss_bbox'].sum()
         onegt_obj_loss = one_gt_losses['loss_obj'].sum()
@@ -142,11 +139,10 @@ def test_loss_by_feat(self):
                            'l1 loss should be non-zero')
 
         # Test groud truth out of bound
-        gt_instances = InstanceData(
-            bboxes=torch.Tensor([[s * 4, s * 4, s * 4 + 10, s * 4 + 10]]),
-            labels=torch.LongTensor([2]))
+        gt_instances = torch.Tensor(
+            [[0, 2, s * 4, s * 4, s * 4 + 10, s * 4 + 10]])
         empty_gt_losses = head.loss_by_feat(cls_scores, bbox_preds,
-                                            objectnesses, [gt_instances],
+                                            objectnesses, gt_instances,
                                             img_metas)
         # When gt_bboxes out of bound, the assign results should be empty,
         # so the cls and bbox loss should be zero.

tests/test_models/test_detectors/test_yolo_detector.py

+3 -4
@@ -21,7 +21,7 @@ def setUp(self):
     @parameterized.expand([
         'yolov5/yolov5_n-v61_syncbn_fast_8xb16-300e_coco.py',
         'yolov6/yolov6_s_syncbn_fast_8xb32-400e_coco.py',
-        'yolox/yolox_tiny_8xb8-300e_coco.py',
+        'yolox/yolox_tiny_fast_8xb8-300e_coco.py',
         'rtmdet/rtmdet_tiny_syncbn_fast_8xb32-300e_coco.py',
         'yolov7/yolov7_tiny_syncbn_fast_8x16b-300e_coco.py',
         'yolov8/yolov8_n_syncbn_fast_8xb16-500e_coco.py'
@@ -38,7 +38,6 @@ def test_init(self, cfg_file):
 
     @parameterized.expand([
         ('yolov5/yolov5_s-v61_syncbn_8xb16-300e_coco.py', ('cuda', 'cpu')),
-        ('yolox/yolox_s_8xb8-300e_coco.py', ('cuda', 'cpu')),
         ('yolov7/yolov7_tiny_syncbn_fast_8x16b-300e_coco.py', ('cuda', 'cpu')),
         ('rtmdet/rtmdet_tiny_syncbn_fast_8xb32-300e_coco.py', ('cuda', 'cpu')),
         ('yolov8/yolov8_n_syncbn_fast_8xb16-500e_coco.py', ('cuda', 'cpu'))
@@ -79,7 +78,7 @@ def test_forward_loss_mode(self, cfg_file, devices):
         ('yolov5/yolov5_n-v61_syncbn_fast_8xb16-300e_coco.py', ('cuda',
                                                                 'cpu')),
         ('yolov6/yolov6_s_syncbn_fast_8xb32-400e_coco.py', ('cuda', 'cpu')),
-        ('yolox/yolox_tiny_8xb8-300e_coco.py', ('cuda', 'cpu')),
+        ('yolox/yolox_tiny_fast_8xb8-300e_coco.py', ('cuda', 'cpu')),
         ('yolov7/yolov7_tiny_syncbn_fast_8x16b-300e_coco.py', ('cuda', 'cpu')),
         ('rtmdet/rtmdet_tiny_syncbn_fast_8xb32-300e_coco.py', ('cuda', 'cpu')),
         ('yolov8/yolov8_n_syncbn_fast_8xb16-500e_coco.py', ('cuda', 'cpu'))
@@ -112,7 +111,7 @@ def test_forward_predict_mode(self, cfg_file, devices):
         ('yolov5/yolov5_n-v61_syncbn_fast_8xb16-300e_coco.py', ('cuda',
                                                                 'cpu')),
         ('yolov6/yolov6_s_syncbn_fast_8xb32-400e_coco.py', ('cuda', 'cpu')),
-        ('yolox/yolox_tiny_8xb8-300e_coco.py', ('cuda', 'cpu')),
+        ('yolox/yolox_tiny_fast_8xb8-300e_coco.py', ('cuda', 'cpu')),
         ('yolov7/yolov7_tiny_syncbn_fast_8x16b-300e_coco.py', ('cuda', 'cpu')),
         ('rtmdet/rtmdet_tiny_syncbn_fast_8xb32-300e_coco.py', ('cuda', 'cpu')),
         ('yolov8/yolov8_n_syncbn_fast_8xb16-500e_coco.py', ('cuda', 'cpu'))
