-
Notifications
You must be signed in to change notification settings - Fork 1.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Refactor]: Unified parameter initialization #622
Changes from 53 commits
b896597
010024e
d52c837
d317b7c
d42898e
be73dd2
f1838d7
e2c74a2
1781db5
a649d1a
487823d
2fa3dd6
69c6ce6
7ba358f
9301014
953a4b7
b824fcb
30e8195
613f503
46afb23
458ec22
e9f6630
af9524f
87b04f6
3f99ed3
94924ad
188da1e
33292c0
2f5382c
455d2de
d443ad4
7f391e1
6f926d7
3fad6aa
15d9188
1eb9274
a0eecb9
dc3f2fb
6439d61
20dfaba
4a5857c
ff56ae9
8830d63
1082498
d91e120
ca6e908
68eec71
37f4aee
e2dc867
17c2485
d24204e
b5f2da2
0e43d8d
19c2ef3
245aea0
85a2fbe
8315abe
e9aaeec
6156cf9
2d3a88b
e142171
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,7 +17,7 @@ def digit_version(version_str): | |
return digit_version | ||
|
||
|
||
mmcv_minimum_version = '1.3.1' | ||
mmcv_minimum_version = '1.3.5' | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should set 1.3.8 |
||
mmcv_maximum_version = '1.4.0' | ||
mmcv_version = digit_version(mmcv.__version__) | ||
|
||
|
@@ -27,17 +27,17 @@ def digit_version(version_str): | |
f'MMCV=={mmcv.__version__} is used but incompatible. ' \ | ||
f'Please install mmcv>={mmcv_minimum_version}, <={mmcv_maximum_version}.' | ||
|
||
mmdet_minimum_version = '2.10.0' | ||
mmdet_maximum_version = '2.11.0' | ||
mmdet_minimum_version = '2.12.0' | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should be 2.14.0 then. |
||
mmdet_maximum_version = '3.0.0' | ||
mmdet_version = digit_version(mmdet.__version__) | ||
assert (mmdet_version >= digit_version(mmdet_minimum_version) | ||
and mmdet_version <= digit_version(mmdet_maximum_version)), \ | ||
f'MMDET=={mmdet.__version__} is used but incompatible. ' \ | ||
f'Please install mmdet>={mmdet_minimum_version}, ' \ | ||
f'<={mmdet_maximum_version}.' | ||
|
||
mmseg_minimum_version = '0.14.0' | ||
mmseg_maximum_version = '0.14.0' | ||
mmseg_minimum_version = '0.14.1' | ||
mmseg_maximum_version = '0.15.0' | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. may set 1.0.0 for now. Usually there is not many BC-breakings. |
||
mmseg_version = digit_version(mmseg.__version__) | ||
assert (mmseg_version >= digit_version(mmseg_minimum_version) | ||
and mmseg_version <= digit_version(mmseg_maximum_version)), \ | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,12 +1,13 @@ | ||
import warnings | ||
from mmcv.cnn import build_conv_layer, build_norm_layer | ||
from mmcv.runner import load_checkpoint | ||
from mmcv.runner import BaseModule | ||
from torch import nn as nn | ||
|
||
from mmdet.models import BACKBONES | ||
|
||
|
||
@BACKBONES.register_module() | ||
class SECOND(nn.Module): | ||
class SECOND(BaseModule): | ||
"""Backbone network for SECOND/PointPillars/PartA2/MVXNet. | ||
|
||
Args: | ||
|
@@ -24,8 +25,10 @@ def __init__(self, | |
layer_nums=[3, 5, 5], | ||
layer_strides=[2, 2, 2], | ||
norm_cfg=dict(type='BN', eps=1e-3, momentum=0.01), | ||
conv_cfg=dict(type='Conv2d', bias=False)): | ||
super(SECOND, self).__init__() | ||
conv_cfg=dict(type='Conv2d', bias=False), | ||
init_cfg=None, | ||
pretrained=None): | ||
super(SECOND, self).__init__(init_cfg=init_cfg) | ||
assert len(layer_strides) == len(layer_nums) | ||
assert len(out_channels) == len(layer_nums) | ||
|
||
|
@@ -61,14 +64,14 @@ def __init__(self, | |
|
||
self.blocks = nn.ModuleList(blocks) | ||
|
||
def init_weights(self, pretrained=None): | ||
"""Initialize weights of the 2D backbone.""" | ||
# Do not initialize the conv layers | ||
# to follow the original implementation | ||
assert not (init_cfg and pretrained), \ | ||
'init_cfg and pretrained cannot be setting at the same time' | ||
if isinstance(pretrained, str): | ||
from mmdet3d.utils import get_root_logger | ||
logger = get_root_logger() | ||
load_checkpoint(self, pretrained, strict=False, logger=logger) | ||
warnings.warn('DeprecationWarning: pretrained is a deprecated, ' | ||
'please use "init_cfg" instead') | ||
self.init_cfg = dict(type='Pretrained', checkpoint=pretrained) | ||
else: | ||
self.init_cfg = dict(type='Kaiming', layer='Conv2d') | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. According to L66-67, why use Kaiming init here? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. if we dont use the pretrained model, we want to use kaiming_init to init all Conv2d layer. |
||
|
||
def forward(self, x): | ||
"""Forward function. | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
change #378 to #622