[Refactoring] Unified parameters initialization #780

Merged (25 commits) on Feb 7, 2021

Conversation

@MeowZheng (Collaborator) commented Jan 7, 2021

I mainly revised 3 files:

  1. In weight_init.py, add Constant, Kaiming, Normal, Pretrained, Uniform, and Xavier classes, and register them in the INITIALIZERS registry; add an initialize function that initializes parameters according to init_cfg.
  2. In checkpoint.py, add the _load_checkpoint_with_prefix function (see the usage sketch after this list).
  3. Add base_module.py, which defines BaseModule and implements only the init_weights method for parameter initialization.
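For illustration, here is a hedged sketch of how _load_checkpoint_with_prefix could be used to load only a sub-module's weights from a full checkpoint; the prefix value, checkpoint path, and the already-built model are hypothetical:

from mmcv.runner.checkpoint import _load_checkpoint_with_prefix

# Load only the parameters whose keys start with 'backbone.' from a
# full checkpoint; the prefix is stripped from the returned keys.
state_dict = _load_checkpoint_with_prefix(
    prefix='backbone',
    filename='checkpoints/full_model.pth',  # hypothetical path
    map_location='cpu')
model.backbone.load_state_dict(state_dict)  # model built beforehand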

Design

Model initialization in OpenMMLab combines init_cfg, BaseModule.init_weights, the initialize function, and the INITIALIZERS registry. Users can initialize their models in the following two steps:

  1. Define init_cfg for a model or its components in model_cfg; the init_cfg of child components has higher priority and will override the init_cfg of parent modules.
  2. Build the model as usual, then call the model.init_weights() method explicitly; the model parameters will be initialized according to the configuration.

The high-level workflow of initialization in OpenMMLab is:
model_cfg(init_cfg) -> build_from_cfg -> model -> init_weights() -> initialize(self, self.init_cfg) -> children's init_weights()

APIs

init_cfg

init_cfg is a dict or list[dict] with the following keys (see the example after this list):

  • type - str, the name of an initializer in INITIALIZERS, followed by the arguments of that initializer.
  • layer - str or list[str], the names of basic layers in PyTorch or MMCV with learnable parameters that will be initialized, e.g. 'Conv2d', 'DeformConv2d'.
  • override - dict or list[dict], for sub-modules that do not inherit from BaseModule and whose initialization differs from that of the layers listed in layer. The initializer given in type applies to all layers listed in layer, so a sub-module that is not derived from BaseModule but can be initialized the same way as those layers does not need override. Each override entry contains:
    • type, followed by the arguments of the initializer;
    • name, indicating the sub-module to be initialized.
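For instance, here is a sketch of an init_cfg combining these keys; the layer names are real PyTorch layers, while the sub-module name 'fc_cls' is a hypothetical example:

init_cfg = dict(
    type='Normal',               # initializer name registered in INITIALIZERS
    layer=['Conv2d', 'Linear'],  # layer types whose parameters to initialize
    std=0.01,                    # argument forwarded to the Normal initializer
    override=dict(               # special-case a named sub-module
        type='Normal', name='fc_cls', std=0.001))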

BaseModule

BaseModule is the base module for all modules in OpenMMLab. The init_weights method of BaseModule initializes the module's own parameters using the initialize(module, init_cfg) function from mmcv, and then calls the init_weights() method of each sub-component.
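Conceptually, BaseModule behaves like the following minimal sketch (a simplification for illustration, not the actual mmcv implementation, which additionally tracks whether a module has already been initialized):

import torch.nn as nn
from mmcv.cnn import initialize

class BaseModule(nn.Module):
    def __init__(self, init_cfg=None):
        super().__init__()
        self.init_cfg = init_cfg

    def init_weights(self):
        # Initialize this module's own parameters from its init_cfg ...
        if self.init_cfg is not None:
            initialize(self, self.init_cfg)
        # ... then recurse into sub-components that define init_weights.
        for m in self.children():
            if hasattr(m, 'init_weights'):
                m.init_weights()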

initialize(module, init_cfg)

  • module - the module to be initialized.
  • init_cfg - initialization configuration dict.
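A standalone usage example; the model and the initializer arguments are chosen arbitrarily for illustration:

import torch.nn as nn
from mmcv.cnn import initialize

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.Conv2d(8, 2, 1))
# Set all Conv2d weights to 0.5 and all biases to 0.
initialize(model, dict(type='Constant', layer='Conv2d', val=0.5, bias=0))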

INITIALIZERS registry

OpenMMLab has implemented six initializers, namely Constant, Xavier, Normal, Uniform, Kaiming, and Pretrained, and registers them in INITIALIZERS.

Taking advantage of the "builder & registry" mechanism of OpenMMLab, INITIALIZERS can be easily extended by implementing new initializer classes and registering them in INITIALIZERS.
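For example, a custom initializer could be registered as follows. The ZeroInit class is hypothetical; the registration pattern is the standard OpenMMLab registry usage:

from mmcv.cnn import INITIALIZERS

@INITIALIZERS.register_module(name='Zero')
class ZeroInit:
    """Hypothetical initializer that zeroes all parameters of a module."""

    def __call__(self, module):
        for param in module.parameters():
            param.data.zero_()

# It can then be referenced in init_cfg as dict(type='Zero').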

Usages

To initialize OpenMMLab models, users just need two steps: 1. define init_cfg; 2. build the model and call model.init_weights().

define init_cfg for a model

FooModel, FooConv1d, FooConv2d, and FooLinear are all derived from BaseModule. If we would like FooModel to initialize all linear layers with weight 1 and bias 2, all conv1d layers with weight 3 and bias 4, and all conv2d layers with weight 5 and bias 6, we can define model_cfg and init_cfg as follows:

model_cfg = dict(
    type="FooModel",
    init_cfg=[
        dict(type='Constant', val=1, bias=2, layer='Linear'),
        dict(type='Constant', val=3, bias=4, layer='Conv1d'),
        dict(type='Constant', val=5, bias=6, layer='Conv2d')
    ],
    component1=dict(type='FooConv1d'),
    component2=dict(type='FooConv2d'),
    component3=dict(type='FooLinear'),
    component4=dict(
        type='FooLinearConv1d',
        linear=dict(type='FooLinear'),
        conv1d=dict(type='FooConv1d')))

After this, we build a FooModel instance and call init_weights:

model = build_from_cfg(model_cfg, FOOMODELS)
model.init_weights()

define nested init_cfg

The init_cfg of sub-modules will override that of their parents, for example:

model_cfg = dict(
    type="FooModel",
    init_cfg=[
        dict(type='Constant', val=1, bias=2, layer='Linear',
            override=dict(type='Constant', name='reg', val=13, bias=14)),
        dict(type='Constant', val=3, bias=4, layer='Conv1d'),
        dict(type='Constant', val=5, bias=6, layer='Conv2d'),
    ],
    component1=dict(
        type='FooConv1d', init_cfg=dict(type='Constant', val=7, bias=8)),
    component2=dict(
        type='FooConv2d', init_cfg=dict(type='Constant', val=9, bias=10)),
    component3=dict(type='FooLinear'),
    component4=dict(
        type='FooLinearConv1d',
        linear=dict(type='FooLinear'),
        conv1d=dict(type='FooConv1d')))

After model = build_from_cfg(model_cfg, FOOMODELS) and model.init_weights(), the parameters will be:

model (FooModel)

  • component1 (FooConv1d, weight=7, bias=8)
  • component2 (FooConv2d, weight=9, bias=10)
  • component3 (FooLinear, weight=1, bias=2)
  • component4 (FooLinearConv1d)
    • linear (FooLinear, weight=1, bias=2)
    • conv1d (FooConv1d, weight=3, bias=4)
  • reg (nn.Linear, weight=13, bias=14)

Migration

  1. Models that inherit from nn.Module must inherit from BaseModule instead.
  2. Add an init_cfg argument to the __init__ of derived classes, and set a default value for init_cfg:
    If init_weights in the current class only recursively calls the init_weights of children modules, such as
def init_weights(self, pretrained):
    """Initialize the weights in head.

    Args:
        pretrained (str, optional): Path to pre-trained weights.
            Defaults to None.
    """
    if self.with_shared_head:
        self.shared_head.init_weights(pretrained=pretrained)
    if self.with_bbox:
        self.bbox_roi_extractor.init_weights()
        self.bbox_head.init_weights()
    if self.with_mask:
        self.mask_head.init_weights()
        if not self.share_roi_extractor:
            self.mask_roi_extractor.init_weights()

just set init_cfg = None. Otherwise, set the init_cfg value according to the current code in init_weights, e.g.

# init_weights from retina_head
def init_weights(self):
    """Initialize weights of the head."""
    for m in self.cls_convs:
        normal_init(m.conv, std=0.01)
    for m in self.reg_convs:
        normal_init(m.conv, std=0.01)
    bias_cls = bias_init_with_prob(0.01)
    normal_init(self.retina_cls, std=0.01, bias=bias_cls)
    normal_init(self.retina_reg, std=0.01)

then the init_cfg must be

init_cfg = dict(
    type='Normal',
    layer='Conv2d',
    std=0.01,
    override=dict(type='Normal', name='retina_cls', std=0.01,
                  bias_prob=0.01))
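Putting step 2 together, the constructor of a derived class could accept and forward init_cfg like the following sketch (the class and arguments are illustrative, following the retina_head example above):

from mmcv.runner import BaseModule

class RetinaHead(BaseModule):
    def __init__(self,
                 num_classes,
                 in_channels,
                 init_cfg=dict(
                     type='Normal', layer='Conv2d', std=0.01,
                     override=dict(type='Normal', name='retina_cls',
                                   std=0.01, bias_prob=0.01)),
                 **kwargs):
        # Forward init_cfg so that BaseModule.init_weights can use it.
        super().__init__(init_cfg=init_cfg)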
  3. Backward compatibility for pretrained in previous config files: add the following code to the __init__ of classes derived from BaseModule:
if pretrained is not None:
    warnings.warn('DeprecationWarning: pretrained is a deprecated \
        key, please consider using init_cfg')
    self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
  4. There is no need to reimplement the init_weights method in derived classes.
  5. Call model.init_weights() after building models. Please pay attention to this, as it is an additional step for models in OpenMMLab.
  6. If users call init_weights on sub-components, or call init_weights on a model twice, there will be a warning: "This module has been initialized, please call initialize(module, init_cfg) to reinitialize it" (illustrated after this list).
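For example, under these rules the second init_weights call below would only emit the warning rather than re-initialize (FOOMODELS and model_cfg as in the Usages section):

model = build_from_cfg(model_cfg, FOOMODELS)
model.init_weights()
model.init_weights()  # warns: module has already been initialized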

BC-breaking

Please inform users in the tutorials to call model.init_weights() after building models.

@ZwwWayne (Collaborator) commented Jan 7, 2021

Please add the design, usages, migration, and BC-breaking notes to the PR message and documentation, both for discussion and as a reference for users.

@ZwwWayne ZwwWayne requested a review from xvjiarui January 7, 2021 14:30
@codecov bot commented Jan 15, 2021

Codecov Report

Merging #780 (45a8746) into master (6c57b88) will increase coverage by 0.69%.
The diff coverage is 86.56%.


@@            Coverage Diff             @@
##           master     #780      +/-   ##
==========================================
+ Coverage   62.23%   62.93%   +0.69%     
==========================================
  Files         144      145       +1     
  Lines        8506     8673     +167     
  Branches     1522     1569      +47     
==========================================
+ Hits         5294     5458     +164     
- Misses       2945     2950       +5     
+ Partials      267      265       -2     
Flag        Coverage Δ
unittests   62.93% <86.56%> (+0.69%) ⬆️

Flags with carried forward coverage won't be shown.

Impacted Files                             Coverage Δ
mmcv/cnn/alexnet.py                        26.08% <0.00%> (-4.35%) ⬇️
mmcv/cnn/resnet.py                         12.19% <0.00%> (-0.61%) ⬇️
mmcv/cnn/vgg.py                            11.11% <0.00%> (-1.02%) ⬇️
mmcv/onnx/onnx_utils/symbolic_helper.py     0.00% <0.00%> (ø)
mmcv/ops/nms.py                            34.43% <8.33%> (ø)
mmcv/utils/parrots_jit.py                  78.94% <66.66%> (+2.47%) ⬆️
mmcv/runner/checkpoint.py                  68.05% <81.81%> (+1.82%) ⬆️
mmcv/runner/base_module.py                 85.71% <85.71%> (ø)
mmcv/cnn/utils/weight_init.py              98.80% <98.59%> (+1.75%) ⬆️
mmcv/cnn/__init__.py                      100.00% <100.00%> (ø)
... and 4 more

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data

@ZwwWayne (Collaborator) commented

The contents of PR messages should also be put into the tutorial to serve as documentation.

@MeowZheng MeowZheng requested a review from ZwwWayne January 18, 2021 09:40
@ZwwWayne (Collaborator) commented Feb 2, 2021

Please resolve conflicts as the load_checkpoint has been refactored.

@ZwwWayne (Collaborator) commented Feb 4, 2021

LGTM now. See if @hellock has any comments.

@MeowZheng MeowZheng requested a review from hellock February 5, 2021 10:27
@apanand14 commented:

File "tools/train.py", line 163, in main
model.init_weights()
File "C:\Users\topseven\anaconda3\envs\mmcv\lib\site-packages\mmcv\runner\base_module.py", line 117, in init_weights
m.init_weights()
TypeError: init_weights() missing 1 required positional argument: 'pretrained'
