Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Improve] Beauty RTMDet config #531

Merged
merged 2 commits into from
Feb 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
143 changes: 98 additions & 45 deletions configs/rtmdet/rtmdet_l_syncbn_fast_8xb32-300e_coco.py
Original file line number Diff line number Diff line change
@@ -1,37 +1,92 @@
_base_ = '../_base_/default_runtime.py'

# ========================Frequently modified parameters======================
# -----data related-----
data_root = 'data/coco/'
dataset_type = 'YOLOv5CocoDataset'

img_scale = (640, 640) # width, height
deepen_factor = 1.0
widen_factor = 1.0
max_epochs = 300
stage2_num_epochs = 20
interval = 10
num_classes = 80
# Path of train annotation file
train_ann_file = 'annotations/instances_train2017.json'
train_data_prefix = 'train2017/' # Prefix of train image path
# Path of val annotation file
val_ann_file = 'annotations/instances_val2017.json'
val_data_prefix = 'val2017/' # Prefix of val image path

num_classes = 80 # Number of classes for classification
# Batch size of a single GPU during training
train_batch_size_per_gpu = 32
# Worker to pre-fetch data for each single GPU during training
train_num_workers = 10
val_batch_size_per_gpu = 32
val_num_workers = 10
# persistent_workers must be False if num_workers is 0.
persistent_workers = True
strides = [8, 16, 32]

# -----train val related-----
# Base learning rate for optim_wrapper. Corresponding to 8xb16=64 bs
base_lr = 0.004
max_epochs = 300 # Maximum training epochs
# Change train_pipeline for final 20 epochs (stage 2)
num_epochs_stage2 = 20

# single-scale training is recommended to
# be turned on, which can speed up training.
env_cfg = dict(cudnn_benchmark=True)
model_test_cfg = dict(
# The config of multi-label for multi-class prediction.
multi_label=True,
# The number of boxes before NMS
nms_pre=30000,
score_thr=0.001, # Threshold to filter out boxes.
nms=dict(type='nms', iou_threshold=0.65), # NMS type and threshold
max_per_img=300) # Max number of detections of each image

# only on Val
# ========================Possible modified parameters========================
# -----data related-----
img_scale = (640, 640) # width, height
# ratio range for random resize
random_resize_ratio_range = (0.1, 2.0)
# Cached images number in mosaic
mosaic_max_cached_images = 40
# Number of cached images in mixup
mixup_max_cached_images = 20
# Dataset type, this will be used to define the dataset
dataset_type = 'YOLOv5CocoDataset'
# Batch size of a single GPU during validation
val_batch_size_per_gpu = 32
# Worker to pre-fetch data for each single GPU during validation
val_num_workers = 10

# Config of batch shapes. Only on val.
batch_shapes_cfg = dict(
type='BatchShapePolicy',
batch_size=val_batch_size_per_gpu,
img_size=img_scale[0],
size_divisor=32,
extra_pad_ratio=0.5)

# -----model related-----
# The scaling factor that controls the depth of the network structure
deepen_factor = 1.0
# The scaling factor that controls the width of the network structure
widen_factor = 1.0
# Strides of multi-scale prior box
strides = [8, 16, 32]

norm_cfg = dict(type='BN') # Normalization config

# -----train val related-----
lr_start_factor = 1.0e-5
dsl_topk = 13 # Number of bbox selected in each level
loss_cls_weight = 1.0
loss_bbox_weight = 2.0
qfl_beta = 2.0 # beta of QualityFocalLoss
weight_decay = 0.05

# Save model checkpoint and validation intervals
save_checkpoint_intervals = 10
# validation intervals in stage 2
val_interval_stage2 = 1
# The maximum checkpoints to keep.
max_keep_ckpts = 3
# single-scale training is recommended to
# be turned on, which can speed up training.
env_cfg = dict(cudnn_benchmark=True)

# ===============================Unmodified in most cases====================
model = dict(
type='YOLODetector',
data_preprocessor=dict(
Expand All @@ -46,7 +101,7 @@
deepen_factor=deepen_factor,
widen_factor=widen_factor,
channel_attention=True,
norm_cfg=dict(type='BN'),
norm_cfg=norm_cfg,
act_cfg=dict(type='SiLU', inplace=True)),
neck=dict(
type='CSPNeXtPAFPN',
Expand All @@ -56,7 +111,7 @@
out_channels=256,
num_csp_blocks=3,
expand_ratio=0.5,
norm_cfg=dict(type='BN'),
norm_cfg=norm_cfg,
act_cfg=dict(type='SiLU', inplace=True)),
bbox_head=dict(
type='RTMDetHead',
Expand All @@ -66,7 +121,7 @@
in_channels=256,
stacked_convs=2,
feat_channels=256,
norm_cfg=dict(type='BN'),
norm_cfg=norm_cfg,
act_cfg=dict(type='SiLU', inplace=True),
share_conv=True,
pred_kernel_size=1,
Expand All @@ -77,24 +132,19 @@
loss_cls=dict(
type='mmdet.QualityFocalLoss',
use_sigmoid=True,
beta=2.0,
loss_weight=1.0),
loss_bbox=dict(type='mmdet.GIoULoss', loss_weight=2.0)),
beta=qfl_beta,
loss_weight=loss_cls_weight),
loss_bbox=dict(type='mmdet.GIoULoss', loss_weight=loss_bbox_weight)),
train_cfg=dict(
assigner=dict(
type='BatchDynamicSoftLabelAssigner',
num_classes=num_classes,
topk=13,
topk=dsl_topk,
iou_calculator=dict(type='mmdet.BboxOverlaps2D')),
allowed_border=-1,
pos_weight=-1,
debug=False),
test_cfg=dict(
multi_label=True,
nms_pre=30000,
score_thr=0.001,
nms=dict(type='nms', iou_threshold=0.65),
max_per_img=300),
test_cfg=model_test_cfg,
)

train_pipeline = [
Expand All @@ -104,20 +154,23 @@
type='Mosaic',
img_scale=img_scale,
use_cached=True,
max_cached_images=40,
max_cached_images=mosaic_max_cached_images,
pad_val=114.0),
dict(
type='mmdet.RandomResize',
# img_scale is (width, height)
scale=(img_scale[0] * 2, img_scale[1] * 2),
ratio_range=(0.1, 2.0),
ratio_range=random_resize_ratio_range,
resize_type='mmdet.Resize',
keep_ratio=True),
dict(type='mmdet.RandomCrop', crop_size=img_scale),
dict(type='mmdet.YOLOXHSVRandomAug'),
dict(type='mmdet.RandomFlip', prob=0.5),
dict(type='mmdet.Pad', size=img_scale, pad_val=dict(img=(114, 114, 114))),
dict(type='YOLOv5MixUp', use_cached=True, max_cached_images=20),
dict(
type='YOLOv5MixUp',
use_cached=True,
max_cached_images=mixup_max_cached_images),
dict(type='mmdet.PackDetInputs')
]

Expand All @@ -127,7 +180,7 @@
dict(
type='mmdet.RandomResize',
scale=img_scale,
ratio_range=(0.1, 2.0),
ratio_range=random_resize_ratio_range,
resize_type='mmdet.Resize',
keep_ratio=True),
dict(type='mmdet.RandomCrop', crop_size=img_scale),
Expand Down Expand Up @@ -162,8 +215,8 @@
dataset=dict(
type=dataset_type,
data_root=data_root,
ann_file='annotations/instances_train2017.json',
data_prefix=dict(img='train2017/'),
ann_file=train_ann_file,
data_prefix=dict(img=train_data_prefix),
filter_cfg=dict(filter_empty_gt=True, min_size=32),
pipeline=train_pipeline))

Expand All @@ -177,8 +230,8 @@
dataset=dict(
type=dataset_type,
data_root=data_root,
ann_file='annotations/instances_val2017.json',
data_prefix=dict(img='val2017/'),
ann_file=val_ann_file,
data_prefix=dict(img=val_data_prefix),
test_mode=True,
batch_shapes_cfg=batch_shapes_cfg,
pipeline=test_pipeline))
Expand All @@ -189,22 +242,22 @@
val_evaluator = dict(
type='mmdet.CocoMetric',
proposal_nums=(100, 1, 10),
ann_file=data_root + 'annotations/instances_val2017.json',
ann_file=data_root + val_ann_file,
metric='bbox')
test_evaluator = val_evaluator

# optimizer
optim_wrapper = dict(
type='OptimWrapper',
optimizer=dict(type='AdamW', lr=base_lr, weight_decay=0.05),
optimizer=dict(type='AdamW', lr=base_lr, weight_decay=weight_decay),
paramwise_cfg=dict(
norm_decay_mult=0, bias_decay_mult=0, bypass_duplicate=True))

# learning rate
param_scheduler = [
dict(
type='LinearLR',
start_factor=1.0e-5,
start_factor=lr_start_factor,
by_epoch=False,
begin=0,
end=1000),
Expand All @@ -223,8 +276,8 @@
default_hooks = dict(
checkpoint=dict(
type='CheckpointHook',
interval=interval,
max_keep_ckpts=3 # only keep latest 3 checkpoints
interval=save_checkpoint_intervals,
max_keep_ckpts=max_keep_ckpts # only keep latest 3 checkpoints
))

custom_hooks = [
Expand All @@ -237,15 +290,15 @@
priority=49),
dict(
type='mmdet.PipelineSwitchHook',
switch_epoch=max_epochs - stage2_num_epochs,
switch_epoch=max_epochs - num_epochs_stage2,
switch_pipeline=train_pipeline_stage2)
]

train_cfg = dict(
type='EpochBasedTrainLoop',
max_epochs=max_epochs,
val_interval=interval,
dynamic_intervals=[(max_epochs - stage2_num_epochs, 1)])
val_interval=save_checkpoint_intervals,
dynamic_intervals=[(max_epochs - num_epochs_stage2, val_interval_stage2)])

val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')
2 changes: 2 additions & 0 deletions configs/rtmdet/rtmdet_m_syncbn_fast_8xb32-300e_coco.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
_base_ = './rtmdet_l_syncbn_fast_8xb32-300e_coco.py'

# ========================modified parameters======================
deepen_factor = 0.67
widen_factor = 0.75

# =======================Unmodified in most cases==================
model = dict(
backbone=dict(deepen_factor=deepen_factor, widen_factor=widen_factor),
neck=dict(deepen_factor=deepen_factor, widen_factor=widen_factor),
Expand Down
22 changes: 17 additions & 5 deletions configs/rtmdet/rtmdet_s_syncbn_fast_8xb32-300e_coco.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,19 @@
_base_ = './rtmdet_l_syncbn_fast_8xb32-300e_coco.py'
checkpoint = 'https://download.openmmlab.com/mmdetection/v3.0/rtmdet/cspnext_rsb_pretrain/cspnext-s_imagenet_600e.pth' # noqa

# ========================modified parameters======================
deepen_factor = 0.33
widen_factor = 0.5
img_scale = _base_.img_scale

# ratio range for random resize
random_resize_ratio_range = (0.5, 2.0)
# Number of cached images in mosaic
mosaic_max_cached_images = 40
# Number of cached images in mixup
mixup_max_cached_images = 20

# =======================Unmodified in most cases==================
model = dict(
backbone=dict(
deepen_factor=deepen_factor,
Expand All @@ -30,20 +39,23 @@
type='Mosaic',
img_scale=img_scale,
use_cached=True,
max_cached_images=40,
max_cached_images=mosaic_max_cached_images,
pad_val=114.0),
dict(
type='mmdet.RandomResize',
# img_scale is (width, height)
scale=(img_scale[0] * 2, img_scale[1] * 2),
ratio_range=(0.5, 2.0), # note
ratio_range=random_resize_ratio_range, # note
resize_type='mmdet.Resize',
keep_ratio=True),
dict(type='mmdet.RandomCrop', crop_size=img_scale),
dict(type='mmdet.YOLOXHSVRandomAug'),
dict(type='mmdet.RandomFlip', prob=0.5),
dict(type='mmdet.Pad', size=img_scale, pad_val=dict(img=(114, 114, 114))),
dict(type='YOLOv5MixUp', use_cached=True, max_cached_images=20),
dict(
type='YOLOv5MixUp',
use_cached=True,
max_cached_images=mixup_max_cached_images),
dict(type='mmdet.PackDetInputs')
]

Expand All @@ -53,7 +65,7 @@
dict(
type='mmdet.RandomResize',
scale=img_scale,
ratio_range=(0.5, 2.0), # note
ratio_range=random_resize_ratio_range, # note
resize_type='mmdet.Resize',
keep_ratio=True),
dict(type='mmdet.RandomCrop', crop_size=img_scale),
Expand All @@ -75,6 +87,6 @@
priority=49),
dict(
type='mmdet.PipelineSwitchHook',
switch_epoch=_base_.max_epochs - _base_.stage2_num_epochs,
switch_epoch=_base_.max_epochs - _base_.num_epochs_stage2,
switch_pipeline=train_pipeline_stage2)
]
Loading