diff --git a/mmcv/cnn/__init__.py b/mmcv/cnn/__init__.py index dd8152dc8f..06f2980219 100644 --- a/mmcv/cnn/__init__.py +++ b/mmcv/cnn/__init__.py @@ -13,9 +13,11 @@ build_upsample_layer, conv_ws_2d, is_norm) # yapf: enable from .resnet import ResNet, make_res_layer -from .utils import (bias_init_with_prob, caffe2_xavier_init, constant_init, - fuse_conv_bn, get_model_complexity_info, kaiming_init, - normal_init, uniform_init, xavier_init) +from .utils import (INITIALIZERS, ConstantInit, KaimingInit, NormalInit, + PretrainedInit, UniformInit, XavierInit, + bias_init_with_prob, caffe2_xavier_init, constant_init, + fuse_conv_bn, get_model_complexity_info, initialize, + kaiming_init, normal_init, uniform_init, xavier_init) from .vgg import VGG, make_vgg_layer __all__ = [ @@ -30,5 +32,6 @@ 'PLUGIN_LAYERS', 'Scale', 'get_model_complexity_info', 'conv_ws_2d', 'ConvAWS2d', 'ConvWS2d', 'fuse_conv_bn', 'DepthwiseSeparableConvModule', 'Linear', 'Conv2d', 'ConvTranspose2d', 'MaxPool2d', 'ConvTranspose3d', - 'MaxPool3d', 'Conv3d' + 'MaxPool3d', 'Conv3d', 'initialize', 'INITIALIZERS', 'ConstantInit', + 'XavierInit', 'NormalInit', 'UniformInit', 'KaimingInit', 'PretrainedInit' ] diff --git a/mmcv/cnn/alexnet.py b/mmcv/cnn/alexnet.py index 4a5c7fbf57..3938d5cd28 100644 --- a/mmcv/cnn/alexnet.py +++ b/mmcv/cnn/alexnet.py @@ -3,8 +3,6 @@ import torch.nn as nn -from ..runner import load_checkpoint - class AlexNet(nn.Module): """AlexNet backbone. @@ -45,6 +43,7 @@ def __init__(self, num_classes=-1): def init_weights(self, pretrained=None): if isinstance(pretrained, str): logger = logging.getLogger() + from ..runner import load_checkpoint load_checkpoint(self, pretrained, strict=False, logger=logger) elif pretrained is None: # use default initializer diff --git a/mmcv/cnn/resnet.py b/mmcv/cnn/resnet.py index e0129eae21..8fe9a3320a 100644 --- a/mmcv/cnn/resnet.py +++ b/mmcv/cnn/resnet.py @@ -4,7 +4,6 @@ import torch.nn as nn import torch.utils.checkpoint as cp -from ..runner import load_checkpoint from .utils import constant_init, kaiming_init @@ -266,6 +265,7 @@ def __init__(self, def init_weights(self, pretrained=None): if isinstance(pretrained, str): logger = logging.getLogger() + from ..runner import load_checkpoint load_checkpoint(self, pretrained, strict=False, logger=logger) elif pretrained is None: for m in self.modules(): diff --git a/mmcv/cnn/utils/__init__.py b/mmcv/cnn/utils/__init__.py index fc96dfcb11..99ec08a786 100644 --- a/mmcv/cnn/utils/__init__.py +++ b/mmcv/cnn/utils/__init__.py @@ -1,12 +1,16 @@ # Copyright (c) Open-MMLab. All rights reserved. 
from .flops_counter import get_model_complexity_info from .fuse_conv_bn import fuse_conv_bn -from .weight_init import (bias_init_with_prob, caffe2_xavier_init, - constant_init, kaiming_init, normal_init, +from .weight_init import (INITIALIZERS, ConstantInit, KaimingInit, NormalInit, + PretrainedInit, UniformInit, XavierInit, + bias_init_with_prob, caffe2_xavier_init, + constant_init, initialize, kaiming_init, normal_init, uniform_init, xavier_init) __all__ = [ 'get_model_complexity_info', 'bias_init_with_prob', 'caffe2_xavier_init', 'constant_init', 'kaiming_init', 'normal_init', 'uniform_init', - 'xavier_init', 'fuse_conv_bn' + 'xavier_init', 'fuse_conv_bn', 'initialize', 'INITIALIZERS', + 'ConstantInit', 'XavierInit', 'NormalInit', 'UniformInit', 'KaimingInit', + 'PretrainedInit' ] diff --git a/mmcv/cnn/utils/weight_init.py b/mmcv/cnn/utils/weight_init.py index e6d92e7a8d..042b7ca54d 100644 --- a/mmcv/cnn/utils/weight_init.py +++ b/mmcv/cnn/utils/weight_init.py @@ -2,6 +2,10 @@ import numpy as np import torch.nn as nn +from mmcv.utils import Registry, build_from_cfg, get_logger, print_log + +INITIALIZERS = Registry('initializer') + def constant_init(module, val, bias=0): if hasattr(module, 'weight') and module.weight is not None: @@ -12,22 +16,25 @@ def constant_init(module, val, bias=0): def xavier_init(module, gain=1, bias=0, distribution='normal'): assert distribution in ['uniform', 'normal'] - if distribution == 'uniform': - nn.init.xavier_uniform_(module.weight, gain=gain) - else: - nn.init.xavier_normal_(module.weight, gain=gain) + if hasattr(module, 'weight') and module.weight is not None: + if distribution == 'uniform': + nn.init.xavier_uniform_(module.weight, gain=gain) + else: + nn.init.xavier_normal_(module.weight, gain=gain) if hasattr(module, 'bias') and module.bias is not None: nn.init.constant_(module.bias, bias) def normal_init(module, mean=0, std=1, bias=0): - nn.init.normal_(module.weight, mean, std) + if hasattr(module, 'weight') and module.weight is not None: + nn.init.normal_(module.weight, mean, std) if hasattr(module, 'bias') and module.bias is not None: nn.init.constant_(module.bias, bias) def uniform_init(module, a=0, b=1, bias=0): - nn.init.uniform_(module.weight, a, b) + if hasattr(module, 'weight') and module.weight is not None: + nn.init.uniform_(module.weight, a, b) if hasattr(module, 'bias') and module.bias is not None: nn.init.constant_(module.bias, bias) @@ -39,12 +46,13 @@ def kaiming_init(module, bias=0, distribution='normal'): assert distribution in ['uniform', 'normal'] - if distribution == 'uniform': - nn.init.kaiming_uniform_( - module.weight, a=a, mode=mode, nonlinearity=nonlinearity) - else: - nn.init.kaiming_normal_( - module.weight, a=a, mode=mode, nonlinearity=nonlinearity) + if hasattr(module, 'weight') and module.weight is not None: + if distribution == 'uniform': + nn.init.kaiming_uniform_( + module.weight, a=a, mode=mode, nonlinearity=nonlinearity) + else: + nn.init.kaiming_normal_( + module.weight, a=a, mode=mode, nonlinearity=nonlinearity) if hasattr(module, 'bias') and module.bias is not None: nn.init.constant_(module.bias, bias) @@ -57,10 +65,367 @@ def caffe2_xavier_init(module, bias=0): a=1, mode='fan_in', nonlinearity='leaky_relu', + bias=bias, distribution='uniform') def bias_init_with_prob(prior_prob): - """initialize conv/fc bias value according to giving probablity.""" + """initialize conv/fc bias value according to giving probability.""" bias_init = float(-np.log((1 - prior_prob) / prior_prob)) return bias_init + + +class 
BaseInit(object):
+
+    def __init__(self, bias, bias_prob, layer):
+        if not isinstance(bias, (int, float)):
+            raise TypeError(f'bias must be a number, but got a {type(bias)}')
+
+        if bias_prob is not None:
+            if not isinstance(bias_prob, float):
+                raise TypeError(f'bias_prob type must be float, \
+                    but got {type(bias_prob)}')
+
+        if layer is not None:
+            if not isinstance(layer, (str, list)):
+                raise TypeError(f'layer must be str or list[str], \
+                    but got a {type(layer)}')
+
+        if bias_prob is not None:
+            self.bias = bias_init_with_prob(bias_prob)
+        else:
+            self.bias = bias
+        self.layer = [layer] if isinstance(layer, str) else layer
+
+
+@INITIALIZERS.register_module(name='Constant')
+class ConstantInit(BaseInit):
+    """Initialize module parameters with constant values.
+
+    Args:
+        val (int | float): the value to fill the weights in the module with.
+        bias (int | float): the value to fill the bias or define
+            initialization type for bias. Defaults to 0.
+        bias_prob (float, optional): the probability for bias initialization.
+            Defaults to None.
+        layer (str | list[str], optional): the layer to be initialized.
+            Defaults to None.
+    """
+
+    def __init__(self, val, bias=0, bias_prob=None, layer=None):
+        super().__init__(bias, bias_prob, layer)
+        self.val = val
+
+    def __call__(self, module):
+
+        def init(m):
+            if self.layer is None:
+                constant_init(m, self.val, self.bias)
+            else:
+                layername = m.__class__.__name__
+                for layer_ in self.layer:
+                    if layername == layer_:
+                        constant_init(m, self.val, self.bias)
+
+        module.apply(init)
+
+
+@INITIALIZERS.register_module(name='Xavier')
+class XavierInit(BaseInit):
+    r"""Initialize module parameters with values according to the method
+    described in `Understanding the difficulty of training deep feedforward
+    neural networks - Glorot, X. & Bengio, Y. (2010)
+    <http://proceedings.mlr.press/v9/glorot10a/glorot10a.pdf>`_.
+
+    Args:
+        gain (int | float): an optional scaling factor. Defaults to 1.
+        bias (int | float): the value to fill the bias or define
+            initialization type for bias. Defaults to 0.
+        bias_prob (float, optional): the probability for bias initialization.
+            Defaults to None.
+        distribution (str): distribution to use, either ``'normal'``
+            or ``'uniform'``. Defaults to ``'normal'``.
+        layer (str | list[str], optional): the layer to be initialized.
+            Defaults to None.
+    """
+
+    def __init__(self,
+                 gain=1,
+                 bias=0,
+                 bias_prob=None,
+                 distribution='normal',
+                 layer=None):
+        super().__init__(bias, bias_prob, layer)
+        self.gain = gain
+        self.distribution = distribution
+
+    def __call__(self, module):
+
+        def init(m):
+            if self.layer is None:
+                xavier_init(m, self.gain, self.bias, self.distribution)
+            else:
+                layername = m.__class__.__name__
+                for layer_ in self.layer:
+                    if layername == layer_:
+                        xavier_init(m, self.gain, self.bias, self.distribution)
+
+        module.apply(init)
+
+
+@INITIALIZERS.register_module(name='Normal')
+class NormalInit(BaseInit):
+    r"""Initialize module parameters with the values drawn from the normal
+    distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`.
+
+    Args:
+        mean (int | float): the mean of the normal distribution.
+            Defaults to 0.
+        std (int | float): the standard deviation of the normal distribution.
+            Defaults to 1.
+        bias (int | float): the value to fill the bias or define
+            initialization type for bias. Defaults to 0.
+        bias_prob (float, optional): the probability for bias initialization.
+            Defaults to None.
+        layer (str | list[str], optional): the layer to be initialized.
+            Defaults to None.
+
+    """
+
+    def __init__(self, mean=0, std=1, bias=0, bias_prob=None, layer=None):
+        super().__init__(bias, bias_prob, layer)
+        self.mean = mean
+        self.std = std
+
+    def __call__(self, module):
+
+        def init(m):
+            if self.layer is None:
+                normal_init(m, self.mean, self.std, self.bias)
+            else:
+                layername = m.__class__.__name__
+                for layer_ in self.layer:
+                    if layername == layer_:
+                        normal_init(m, self.mean, self.std, self.bias)
+
+        module.apply(init)
+
+
+@INITIALIZERS.register_module(name='Uniform')
+class UniformInit(BaseInit):
+    r"""Initialize module parameters with values drawn from the uniform
+    distribution :math:`\mathcal{U}(a, b)`.
+
+    Args:
+        a (int | float): the lower bound of the uniform distribution.
+            Defaults to 0.
+        b (int | float): the upper bound of the uniform distribution.
+            Defaults to 1.
+        bias (int | float): the value to fill the bias or define
+            initialization type for bias. Defaults to 0.
+        bias_prob (float, optional): the probability for bias initialization.
+            Defaults to None.
+        layer (str | list[str], optional): the layer to be initialized.
+            Defaults to None.
+    """
+
+    def __init__(self, a=0, b=1, bias=0, bias_prob=None, layer=None):
+        super().__init__(bias, bias_prob, layer)
+        self.a = a
+        self.b = b
+
+    def __call__(self, module):
+
+        def init(m):
+            if self.layer is None:
+                uniform_init(m, self.a, self.b, self.bias)
+            else:
+                layername = m.__class__.__name__
+                for layer_ in self.layer:
+                    if layername == layer_:
+                        uniform_init(m, self.a, self.b, self.bias)
+
+        module.apply(init)
+
+
+@INITIALIZERS.register_module(name='Kaiming')
+class KaimingInit(BaseInit):
+    r"""Initialize module parameters with the values according to the method
+    described in `Delving deep into rectifiers: Surpassing human-level
+    performance on ImageNet classification - He, K. et al. (2015)
+    <https://arxiv.org/abs/1502.01852>`_.
+
+    Args:
+        a (int | float): the negative slope of the rectifier used after this
+            layer (only used with ``'leaky_relu'``). Defaults to 0.
+        mode (str): either ``'fan_in'`` or ``'fan_out'``. Choosing
+            ``'fan_in'`` preserves the magnitude of the variance of the
+            weights in the forward pass. Choosing ``'fan_out'`` preserves the
+            magnitudes in the backwards pass. Defaults to ``'fan_out'``.
+        nonlinearity (str): the non-linear function (`nn.functional` name),
+            recommended to use only with ``'relu'`` or ``'leaky_relu'``.
+            Defaults to 'relu'.
+        bias (int | float): the value to fill the bias or define
+            initialization type for bias. Defaults to 0.
+        bias_prob (float, optional): the probability for bias initialization.
+            Defaults to None.
+        distribution (str): distribution to use, either ``'normal'`` or
+            ``'uniform'``. Defaults to ``'normal'``.
+        layer (str | list[str], optional): the layer to be initialized.
+            Defaults to None.
+    """
+
+    def __init__(self,
+                 a=0,
+                 mode='fan_out',
+                 nonlinearity='relu',
+                 bias=0,
+                 bias_prob=None,
+                 distribution='normal',
+                 layer=None):
+        super().__init__(bias, bias_prob, layer)
+        self.a = a
+        self.mode = mode
+        self.nonlinearity = nonlinearity
+        self.distribution = distribution
+
+    def __call__(self, module):
+
+        def init(m):
+            if self.layer is None:
+                kaiming_init(m, self.a, self.mode, self.nonlinearity,
+                             self.bias, self.distribution)
+            else:
+                layername = m.__class__.__name__
+                for layer_ in self.layer:
+                    if layername == layer_:
+                        kaiming_init(m, self.a, self.mode, self.nonlinearity,
+                                     self.bias, self.distribution)
+
+        module.apply(init)
+
+
+@INITIALIZERS.register_module(name='Pretrained')
+class PretrainedInit(object):
+    """Initialize module by loading a pretrained model.
+
+    Args:
+        checkpoint (str): the checkpoint file to load.
+        prefix (str, optional): the prefix to indicate the sub-module.
+            Defaults to None.
+        map_location (str, optional): map tensors into proper locations.
+            Defaults to None.
+    """
+
+    def __init__(self, checkpoint, prefix=None, map_location=None):
+        self.checkpoint = checkpoint
+        self.prefix = prefix
+        self.map_location = map_location
+
+    def __call__(self, module):
+        from mmcv.runner import (_load_checkpoint_with_prefix, load_checkpoint,
+                                 load_state_dict)
+        logger = get_logger('mmcv')
+        if self.prefix is None:
+            print_log(f'load model from: {self.checkpoint}', logger=logger)
+            load_checkpoint(
+                module,
+                self.checkpoint,
+                map_location=self.map_location,
+                strict=False,
+                logger=logger)
+        else:
+            print_log(
+                f'load {self.prefix} in model from: {self.checkpoint}',
+                logger=logger)
+            state_dict = _load_checkpoint_with_prefix(
+                self.prefix, self.checkpoint, map_location=self.map_location)
+            load_state_dict(module, state_dict, strict=False, logger=logger)
+
+
+def _initialize(module, cfg):
+    func = build_from_cfg(cfg, INITIALIZERS)
+    func(module)
+
+
+def _initialize_override(module, override):
+    if not isinstance(override, (dict, list)):
+        raise TypeError(
+            f'override must be a dict or list, but got {type(override)}')
+
+    override = [override] if isinstance(override, dict) else override
+
+    for override_ in override:
+        name = override_.pop('name', None)
+        if hasattr(module, name):
+            _initialize(getattr(module, name), override_)
+        else:
+            raise RuntimeError(f'module does not have attribute {name}')
+
+
+def initialize(module, init_cfg):
+    """Initialize a module.
+
+    Args:
+        module (``torch.nn.Module``): the module to be initialized.
+        init_cfg (dict | list[dict]): initialization configuration dict to
+            define initializer. OpenMMLab has implemented six initializers,
+            including ``Constant``, ``Xavier``, ``Normal``, ``Uniform``,
+            ``Kaiming`` and ``Pretrained``; bias initialization from a
+            probability is supported via the ``bias_prob`` argument.
+
+    Example:
+        >>> module = nn.Linear(2, 3, bias=True)
+        >>> init_cfg = dict(type='Constant', val=1, bias=2)
+        >>> initialize(module, init_cfg)

+        >>> module = nn.Sequential(nn.Conv1d(3, 1, 3), nn.Linear(1, 2))
+        >>> # define key ``'layer'`` for initializing layers with different
+        >>> # configurations
+        >>> init_cfg = [dict(type='Constant', layer='Conv1d', val=1),
                dict(type='Constant', layer='Linear', val=2)]
+        >>> initialize(module, init_cfg)

+        >>> # omitting ``'layer'`` initializes the whole module with the
+        >>> # same configuration
+        >>> init_cfg = dict(type='Constant', val=1, bias=2)
+        >>> initialize(module, init_cfg)

+        >>> # define key ``'override'`` to initialize some specific part of
+        >>> # the module
+        >>> class FooNet(nn.Module):
+        >>>     def __init__(self):
+        >>>         super().__init__()
+        >>>         self.feat = nn.Conv2d(3, 16, 3)
+        >>>         self.reg = nn.Conv2d(16, 10, 3)
+        >>>         self.cls = nn.Conv2d(16, 5, 3)
+        >>> model = FooNet()
+        >>> init_cfg = dict(type='Constant', val=1, bias=2,
+        >>>     override=dict(type='Constant', name='reg', val=3, bias=4))
+        >>> initialize(model, init_cfg)

+        >>> model = ResNet(depth=50)
+        >>> # Initialize weights with the pretrained model.
+        >>> init_cfg = dict(type='Pretrained',
                checkpoint='torchvision://resnet50')
+        >>> initialize(model, init_cfg)

+        >>> # Initialize weights of a sub-module with the specific part of
+        >>> # a pretrained model by using "prefix".
+        >>> url = 'http://download.openmmlab.com/mmdetection/v2.0/retinanet/'\
+        >>>     'retinanet_r50_fpn_1x_coco/'\
+        >>>     'retinanet_r50_fpn_1x_coco_20200130-c2398f9e.pth'
+        >>> init_cfg = dict(type='Pretrained',
                checkpoint=url, prefix='backbone.')
+    """
+    if not isinstance(init_cfg, (dict, list)):
+        raise TypeError('init_cfg must be a dict or a list of dict, '
+                        f'but got {type(init_cfg)}')
+
+    if isinstance(init_cfg, dict):
+        init_cfg = [init_cfg]
+
+    for cfg in init_cfg:
+        override = cfg.pop('override', None)
+        _initialize(module, cfg)
+
+        if override is not None:
+            _initialize_override(module, override)
+        else:
+            # All attributes in module have the same initialization.
+            pass
diff --git a/mmcv/cnn/vgg.py b/mmcv/cnn/vgg.py
index c20b5d0f4d..82f8ba1093 100644
--- a/mmcv/cnn/vgg.py
+++ b/mmcv/cnn/vgg.py
@@ -3,7 +3,6 @@
 import torch.nn as nn

-from ..runner import load_checkpoint
 from .utils import constant_init, kaiming_init, normal_init

@@ -126,6 +125,7 @@ def __init__(self,
     def init_weights(self, pretrained=None):
         if isinstance(pretrained, str):
             logger = logging.getLogger()
+            from ..runner import load_checkpoint
             load_checkpoint(self, pretrained, strict=False, logger=logger)
         elif pretrained is None:
             for m in self.modules():
diff --git a/mmcv/runner/__init__.py b/mmcv/runner/__init__.py
index 17bfee53e9..df5680ff0b 100644
--- a/mmcv/runner/__init__.py
+++ b/mmcv/runner/__init__.py
@@ -1,7 +1,9 @@
 # Copyright (c) Open-MMLab. All rights reserved.
+from .base_module import BaseModule
 from .base_runner import BaseRunner
 from .builder import RUNNERS, build_runner
-from .checkpoint import (CheckpointLoader, _load_checkpoint, load_checkpoint,
+from .checkpoint import (CheckpointLoader, _load_checkpoint,
+                         _load_checkpoint_with_prefix, load_checkpoint,
                          load_state_dict, save_checkpoint, weights_to_cpu)
 from .dist_utils import (allreduce_grads, allreduce_params, get_dist_info,
                          init_dist, master_only)
@@ -34,5 +36,5 @@
     'set_random_seed', 'auto_fp16', 'force_fp32', 'wrap_fp16_model',
     'Fp16OptimizerHook', 'SyncBuffersHook', 'EMAHook', 'build_runner',
     'RUNNERS', 'allreduce_grads', 'allreduce_params', 'LossScaler',
-    'CheckpointLoader'
+    'CheckpointLoader', 'BaseModule', '_load_checkpoint_with_prefix'
 ]
diff --git a/mmcv/runner/base_module.py b/mmcv/runner/base_module.py
new file mode 100644
index 0000000000..f958a66587
--- /dev/null
+++ b/mmcv/runner/base_module.py
@@ -0,0 +1,53 @@
+# Copyright (c) Open-MMLab. All rights reserved.
+import warnings
+from abc import ABCMeta
+
+import torch.nn as nn
+
+
+class BaseModule(nn.Module, metaclass=ABCMeta):
+    """Base module for all modules in OpenMMLab."""
+
+    def __init__(self, init_cfg=None):
+        """Initialize BaseModule, inherited from `torch.nn.Module`.
+
+        Args:
+            init_cfg (dict, optional): Initialization config dict.
+        """
+
+        # NOTE init_cfg can be defined in different levels, but init_cfg
+        # in low levels has a higher priority.
+
+        super(BaseModule, self).__init__()
+        # define default value of init_cfg instead of hard code
+        # in init_weight() function
+        self._is_init = False
+        if init_cfg is not None:
+            self.init_cfg = init_cfg
+
+        # Backward compatibility in derived classes
+        # if pretrained is not None:
+        #     warnings.warn('DeprecationWarning: pretrained is a deprecated '
+        #                   'key, please consider using init_cfg')
+        #     self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
+
+    @property
+    def is_init(self):
+        return self._is_init
+
+    def init_weight(self):
+        """Initialize the weights."""
+        from ..cnn import initialize
+
+        if not self._is_init:
+
+            if hasattr(self, 'init_cfg'):
+                initialize(self, self.init_cfg)
+            self._is_init = True
+            for module in self.children():
+                if 'init_weight' in dir(module):
+                    module.init_weight()
+
+        else:
+            warnings.warn('This module has been initialized, please call '
+                          'initialize(module, init_cfg) to reinitialize it')
diff --git a/mmcv/runner/checkpoint.py b/mmcv/runner/checkpoint.py
index a03f4c0a59..697c9ebdde 100644
--- a/mmcv/runner/checkpoint.py
+++ b/mmcv/runner/checkpoint.py
@@ -464,6 +464,39 @@ def _load_checkpoint(filename, map_location=None, logger=None):
     return CheckpointLoader.load_checkpoint(filename, map_location, logger)


+def _load_checkpoint_with_prefix(prefix, filename, map_location=None):
+    """Load partial pretrained model with specific prefix.
+
+    Args:
+        prefix (str): The prefix of sub-module.
+        filename (str): Accept local filepath, URL, ``torchvision://xxx``,
+            ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
+            details.
+        map_location (str | None): Same as :func:`torch.load`. Default: None.
+
+    Returns:
+        dict or OrderedDict: The loaded checkpoint.
+    """
+
+    checkpoint = _load_checkpoint(filename, map_location=map_location)
+
+    if 'state_dict' in checkpoint:
+        state_dict = checkpoint['state_dict']
+    else:
+        state_dict = checkpoint
+    if not prefix.endswith('.'):
+        prefix += '.'
+ prefix_len = len(prefix) + + state_dict = { + k[prefix_len:]: v + for k, v in state_dict.items() if k.startswith(prefix) + } + + assert state_dict, f'{prefix} is not in the pretrained model' + return state_dict + + def load_checkpoint(model, filename, map_location=None, diff --git a/tests/test_cnn/test_weight_init.py b/tests/test_cnn/test_weight_init.py index ebbd00f7ec..c0113b9545 100644 --- a/tests/test_cnn/test_weight_init.py +++ b/tests/test_cnn/test_weight_init.py @@ -1,10 +1,14 @@ # Copyright (c) Open-MMLab. All rights reserved. +from tempfile import TemporaryDirectory + import numpy as np import pytest import torch from torch import nn -from mmcv.cnn import (bias_init_with_prob, caffe2_xavier_init, constant_init, +from mmcv.cnn import (ConstantInit, KaimingInit, NormalInit, PretrainedInit, + UniformInit, XavierInit, bias_init_with_prob, + caffe2_xavier_init, constant_init, initialize, kaiming_init, normal_init, uniform_init, xavier_init) @@ -75,3 +79,283 @@ def test_bias_init_with_prob(): # TODO: sanity check of weight distribution, e.g. mean, std bias = float(-np.log((1 - prior_prob) / prior_prob)) assert conv_module.bias.allclose(torch.full_like(conv_module.bias, bias)) + + +def test_constaninit(): + """test ConstantInit class.""" + model = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU(), nn.Linear(1, 2)) + func = ConstantInit(val=1, bias=2, layer='Conv2d') + func(model) + assert torch.equal(model[0].weight, torch.full(model[0].weight.shape, 1.)) + assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, 2.)) + + assert not torch.equal(model[2].weight, + torch.full(model[2].weight.shape, 1.)) + assert not torch.equal(model[2].bias, torch.full(model[2].bias.shape, 2.)) + + func = ConstantInit(val=3, bias_prob=0.01, layer='Linear') + func(model) + res = bias_init_with_prob(0.01) + + assert torch.equal(model[0].weight, torch.full(model[0].weight.shape, 1.)) + assert torch.equal(model[2].weight, torch.full(model[2].weight.shape, 3.)) + assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, 2.)) + assert torch.equal(model[2].bias, torch.full(model[2].bias.shape, res)) + + func = ConstantInit(val=4, bias=5) + func(model) + assert torch.equal(model[0].weight, torch.full(model[0].weight.shape, 4.)) + assert torch.equal(model[2].weight, torch.full(model[2].weight.shape, 4.)) + assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, 5.)) + assert torch.equal(model[2].bias, torch.full(model[2].bias.shape, 5.)) + + # test bias input type + with pytest.raises(TypeError): + func = ConstantInit(val=1, bias='1') + # test bias_prob type + with pytest.raises(TypeError): + func = ConstantInit(val=1, bias_prob='1') + # test layer input type + with pytest.raises(TypeError): + func = ConstantInit(val=1, layer=1) + + +def test_xavierinit(): + """test XavierInit class.""" + model = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU(), nn.Linear(1, 2)) + func = XavierInit(bias=0.1, layer='Conv2d') + func(model) + assert model[0].bias.allclose(torch.full_like(model[2].bias, 0.1)) + assert not model[2].bias.allclose(torch.full_like(model[0].bias, 0.1)) + + constant_func = ConstantInit(val=0, bias=0) + func = XavierInit(gain=100, bias_prob=0.01) + model.apply(constant_func) + assert torch.equal(model[0].weight, torch.full(model[0].weight.shape, 0.)) + assert torch.equal(model[2].weight, torch.full(model[2].weight.shape, 0.)) + assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, 0.)) + assert torch.equal(model[2].bias, torch.full(model[2].bias.shape, 0.)) + + res = 
bias_init_with_prob(0.01)
+    func(model)
+    assert not torch.equal(model[0].weight,
+                           torch.full(model[0].weight.shape, 0.))
+    assert not torch.equal(model[2].weight,
+                           torch.full(model[2].weight.shape, 0.))
+    assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, res))
+    assert torch.equal(model[2].bias, torch.full(model[2].bias.shape, res))
+
+    # test bias input type
+    with pytest.raises(TypeError):
+        func = XavierInit(bias='0.1', layer='Conv2d')
+    # test layer input type
+    with pytest.raises(TypeError):
+        func = XavierInit(bias=0.1, layer=1)
+
+
+def test_normalinit():
+    """test NormalInit class."""
+    model = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU(), nn.Linear(1, 2))
+
+    func = NormalInit(mean=100, std=1e-5, bias=200)
+    func(model)
+    assert model[0].weight.allclose(torch.tensor(100.))
+    assert model[2].weight.allclose(torch.tensor(100.))
+    assert model[0].bias.allclose(torch.tensor(200.))
+    assert model[2].bias.allclose(torch.tensor(200.))
+
+    func = NormalInit(
+        mean=300, std=1e-5, bias_prob=0.01, layer=['Conv2d', 'Linear'])
+    res = bias_init_with_prob(0.01)
+    func(model)
+    assert model[0].weight.allclose(torch.tensor(300.))
+    assert model[2].weight.allclose(torch.tensor(300.))
+    assert model[0].bias.allclose(torch.tensor(res))
+    assert model[2].bias.allclose(torch.tensor(res))
+
+
+def test_uniforminit():
+    """test UniformInit class."""
+    model = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU(), nn.Linear(1, 2))
+    func = UniformInit(a=1, b=1, bias=2)
+    func(model)
+    assert torch.equal(model[0].weight, torch.full(model[0].weight.shape, 1.))
+    assert torch.equal(model[2].weight, torch.full(model[2].weight.shape, 1.))
+    assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, 2.))
+    assert torch.equal(model[2].bias, torch.full(model[2].bias.shape, 2.))
+
+    func = UniformInit(a=100, b=100, layer=['Conv2d', 'Linear'], bias=10)
+    func(model)
+    assert torch.equal(model[0].weight, torch.full(model[0].weight.shape,
+                                                   100.))
+    assert torch.equal(model[2].weight, torch.full(model[2].weight.shape,
+                                                   100.))
+    assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, 10.))
+    assert torch.equal(model[2].bias, torch.full(model[2].bias.shape, 10.))
+
+
+def test_kaiminginit():
+    """test KaimingInit class."""
+    model = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU(), nn.Linear(1, 2))
+    func = KaimingInit(bias=0.1, layer='Conv2d')
+    func(model)
+    assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, 0.1))
+    assert not torch.equal(model[2].bias, torch.full(model[2].bias.shape, 0.1))
+
+    func = KaimingInit(a=100, bias=10)
+    constant_func = ConstantInit(val=0, bias=0)
+    model.apply(constant_func)
+    assert torch.equal(model[0].weight, torch.full(model[0].weight.shape, 0.))
+    assert torch.equal(model[2].weight, torch.full(model[2].weight.shape, 0.))
+    assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, 0.))
+    assert torch.equal(model[2].bias, torch.full(model[2].bias.shape, 0.))
+
+    func(model)
+    assert not torch.equal(model[0].weight,
+                           torch.full(model[0].weight.shape, 0.))
+    assert not torch.equal(model[2].weight,
+                           torch.full(model[2].weight.shape, 0.))
+    assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, 10.))
+    assert torch.equal(model[2].bias, torch.full(model[2].bias.shape, 10.))
+
+
+class FooModule(nn.Module):
+
+    def __init__(self):
+        super().__init__()
+        self.linear = nn.Linear(1, 2)
+        self.conv2d = nn.Conv2d(3, 1, 3)
+        self.conv2d_2 = nn.Conv2d(3, 2, 3)
+
+
+def test_pretrainedinit():
+    """test PretrainedInit class."""
+
+    modelA = 
FooModule() + constant_func = ConstantInit(val=1, bias=2) + modelA.apply(constant_func) + modelB = FooModule() + funcB = PretrainedInit(checkpoint='modelA.pth') + modelC = nn.Linear(1, 2) + funcC = PretrainedInit(checkpoint='modelA.pth', prefix='linear.') + with TemporaryDirectory(): + torch.save(modelA.state_dict(), 'modelA.pth') + funcB(modelB) + assert torch.equal(modelB.linear.weight, + torch.full(modelB.linear.weight.shape, 1.)) + assert torch.equal(modelB.linear.bias, + torch.full(modelB.linear.bias.shape, 2.)) + assert torch.equal(modelB.conv2d.weight, + torch.full(modelB.conv2d.weight.shape, 1.)) + assert torch.equal(modelB.conv2d.bias, + torch.full(modelB.conv2d.bias.shape, 2.)) + assert torch.equal(modelB.conv2d_2.weight, + torch.full(modelB.conv2d_2.weight.shape, 1.)) + assert torch.equal(modelB.conv2d_2.bias, + torch.full(modelB.conv2d_2.bias.shape, 2.)) + + funcC(modelC) + assert torch.equal(modelC.weight, torch.full(modelC.weight.shape, 1.)) + assert torch.equal(modelC.bias, torch.full(modelC.bias.shape, 2.)) + + +def test_initialize(): + model = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU(), nn.Linear(1, 2)) + foonet = FooModule() + + init_cfg = dict(type='Constant', val=1, bias=2) + initialize(model, init_cfg) + assert torch.equal(model[0].weight, torch.full(model[0].weight.shape, 1.)) + assert torch.equal(model[2].weight, torch.full(model[2].weight.shape, 1.)) + assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, 2.)) + assert torch.equal(model[2].bias, torch.full(model[2].bias.shape, 2.)) + + init_cfg = [ + dict(type='Constant', layer='Conv1d', val=1, bias=2), + dict(type='Constant', layer='Linear', val=3, bias=4) + ] + initialize(model, init_cfg) + assert torch.equal(model[0].weight, torch.full(model[0].weight.shape, 1.)) + assert torch.equal(model[2].weight, torch.full(model[2].weight.shape, 3.)) + assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, 2.)) + assert torch.equal(model[2].bias, torch.full(model[2].bias.shape, 4.)) + + init_cfg = dict( + type='Constant', + val=1, + bias=2, + layer=['Conv2d', 'Linear'], + override=dict(type='Constant', name='conv2d_2', val=3, bias=4)) + initialize(foonet, init_cfg) + assert torch.equal(foonet.linear.weight, + torch.full(foonet.linear.weight.shape, 1.)) + assert torch.equal(foonet.linear.bias, + torch.full(foonet.linear.bias.shape, 2.)) + assert torch.equal(foonet.conv2d.weight, + torch.full(foonet.conv2d.weight.shape, 1.)) + assert torch.equal(foonet.conv2d.bias, + torch.full(foonet.conv2d.bias.shape, 2.)) + assert torch.equal(foonet.conv2d_2.weight, + torch.full(foonet.conv2d_2.weight.shape, 3.)) + assert torch.equal(foonet.conv2d_2.bias, + torch.full(foonet.conv2d_2.bias.shape, 4.)) + + init_cfg = dict( + type='Pretrained', + checkpoint='modelA.pth', + override=dict(type='Constant', name='conv2d_2', val=3, bias=4)) + modelA = FooModule() + constant_func = ConstantInit(val=1, bias=2) + modelA.apply(constant_func) + with TemporaryDirectory(): + torch.save(modelA.state_dict(), 'modelA.pth') + initialize(foonet, init_cfg) + assert torch.equal(foonet.linear.weight, + torch.full(foonet.linear.weight.shape, 1.)) + assert torch.equal(foonet.linear.bias, + torch.full(foonet.linear.bias.shape, 2.)) + assert torch.equal(foonet.conv2d.weight, + torch.full(foonet.conv2d.weight.shape, 1.)) + assert torch.equal(foonet.conv2d.bias, + torch.full(foonet.conv2d.bias.shape, 2.)) + assert torch.equal(foonet.conv2d_2.weight, + torch.full(foonet.conv2d_2.weight.shape, 3.)) + assert torch.equal(foonet.conv2d_2.bias, + 
torch.full(foonet.conv2d_2.bias.shape, 4.)) + # test init_cfg type + with pytest.raises(TypeError): + init_cfg = 'init_cfg' + initialize(foonet, init_cfg) + + # test override value type + with pytest.raises(TypeError): + init_cfg = dict( + type='Constant', + val=1, + bias=2, + layer=['Conv2d', 'Linear'], + override='conv') + initialize(foonet, init_cfg) + + # test override name + with pytest.raises(RuntimeError): + init_cfg = dict( + type='Constant', + val=1, + bias=2, + layer=['Conv2d', 'Linear'], + override=dict(type='Constant', name='conv2d_3', val=3, bias=4)) + initialize(foonet, init_cfg) + + # test list override name + with pytest.raises(RuntimeError): + init_cfg = dict( + type='Constant', + val=1, + bias=2, + layer=['Conv2d', 'Linear'], + override=[ + dict(type='Constant', name='conv2d', val=3, bias=4), + dict(type='Constant', name='conv2d_3', val=5, bias=6) + ]) + initialize(foonet, init_cfg) diff --git a/tests/test_runner/test_basemodule.py b/tests/test_runner/test_basemodule.py new file mode 100644 index 0000000000..5250077895 --- /dev/null +++ b/tests/test_runner/test_basemodule.py @@ -0,0 +1,228 @@ +import torch +from torch import nn + +from mmcv.runner import BaseModule +from mmcv.utils import Registry, build_from_cfg + +COMPONENTS = Registry('component') +FOOMODELS = Registry('model') + + +@COMPONENTS.register_module() +class FooConv1d(BaseModule): + + def __init__(self, init_cfg=None): + super().__init__(init_cfg) + self.conv1d = nn.Conv1d(4, 1, 4) + + def forward(self, x): + return self.conv1d(x) + + +@COMPONENTS.register_module() +class FooConv2d(BaseModule): + + def __init__(self, init_cfg=None): + super().__init__(init_cfg) + self.conv2d = nn.Conv2d(3, 1, 3) + + def forward(self, x): + return self.conv2d(x) + + +@COMPONENTS.register_module() +class FooLinear(BaseModule): + + def __init__(self, init_cfg=None): + super().__init__(init_cfg) + self.linear = nn.Linear(3, 4) + + def forward(self, x): + return self.linear(x) + + +@COMPONENTS.register_module() +class FooLinearConv1d(BaseModule): + + def __init__(self, linear=None, conv1d=None, init_cfg=None): + super().__init__(init_cfg) + if linear is not None: + self.linear = build_from_cfg(linear, COMPONENTS) + if conv1d is not None: + self.conv1d = build_from_cfg(conv1d, COMPONENTS) + + def forward(self, x): + x = self.linear(x) + return self.conv1d(x) + + +@FOOMODELS.register_module() +class FooModel(BaseModule): + + def __init__(self, + component1=None, + component2=None, + component3=None, + component4=None, + init_cfg=None) -> None: + super().__init__(init_cfg) + if component1 is not None: + self.component1 = build_from_cfg(component1, COMPONENTS) + if component2 is not None: + self.component2 = build_from_cfg(component2, COMPONENTS) + if component3 is not None: + self.component3 = build_from_cfg(component3, COMPONENTS) + if component4 is not None: + self.component4 = build_from_cfg(component4, COMPONENTS) + + # its type is not BaseModule, it can be initialized + # with "override" key. 
+ self.reg = nn.Linear(3, 4) + + +def test_model_weight_init(): + """ + Config + model (FooModel, Linear: weight=1, bias=2, Conv1d: weight=3, bias=4, + Conv2d: weight=5, bias=6) + ├──component1 (FooConv1d) + ├──component2 (FooConv2d) + ├──component3 (FooLinear) + ├──component4 (FooLinearConv1d) + ├──linear (FooLinear) + ├──conv1d (FooConv1d) + ├──reg (nn.Linear) + + Parameters after initialization + model (FooModel) + ├──component1 (FooConv1d, weight=3, bias=4) + ├──component2 (FooConv2d, weight=5, bias=6) + ├──component3 (FooLinear, weight=1, bias=2) + ├──component4 (FooLinearConv1d) + ├──linear (FooLinear, weight=1, bias=2) + ├──conv1d (FooConv1d, weight=3, bias=4) + ├──reg (nn.Linear, weight=1, bias=2) + """ + model_cfg = dict( + type='FooModel', + init_cfg=[ + dict(type='Constant', val=1, bias=2, layer='Linear'), + dict(type='Constant', val=3, bias=4, layer='Conv1d'), + dict(type='Constant', val=5, bias=6, layer='Conv2d') + ], + component1=dict(type='FooConv1d'), + component2=dict(type='FooConv2d'), + component3=dict(type='FooLinear'), + component4=dict( + type='FooLinearConv1d', + linear=dict(type='FooLinear'), + conv1d=dict(type='FooConv1d'))) + + model = build_from_cfg(model_cfg, FOOMODELS) + model.init_weight() + + assert torch.equal(model.component1.conv1d.weight, + torch.full(model.component1.conv1d.weight.shape, 3.0)) + assert torch.equal(model.component1.conv1d.bias, + torch.full(model.component1.conv1d.bias.shape, 4.0)) + assert torch.equal(model.component2.conv2d.weight, + torch.full(model.component2.conv2d.weight.shape, 5.0)) + assert torch.equal(model.component2.conv2d.bias, + torch.full(model.component2.conv2d.bias.shape, 6.0)) + assert torch.equal(model.component3.linear.weight, + torch.full(model.component3.linear.weight.shape, 1.0)) + assert torch.equal(model.component3.linear.bias, + torch.full(model.component3.linear.bias.shape, 2.0)) + assert torch.equal( + model.component4.linear.linear.weight, + torch.full(model.component4.linear.linear.weight.shape, 1.0)) + assert torch.equal( + model.component4.linear.linear.bias, + torch.full(model.component4.linear.linear.bias.shape, 2.0)) + assert torch.equal( + model.component4.conv1d.conv1d.weight, + torch.full(model.component4.conv1d.conv1d.weight.shape, 3.0)) + assert torch.equal( + model.component4.conv1d.conv1d.bias, + torch.full(model.component4.conv1d.conv1d.bias.shape, 4.0)) + assert torch.equal(model.reg.weight, torch.full(model.reg.weight.shape, + 1.0)) + assert torch.equal(model.reg.bias, torch.full(model.reg.bias.shape, 2.0)) + + +def test_nest_components_weight_init(): + """ + Config + model (FooModel, Linear: weight=1, bias=2, Conv1d: weight=3, bias=4, + Conv2d: weight=5, bias=6) + ├──component1 (FooConv1d, Conv1d: weight=7, bias=8) + ├──component2 (FooConv2d, Conv2d: weight=9, bias=10) + ├──component3 (FooLinear) + ├──component4 (FooLinearConv1d, Linear: weight=11, bias=12) + ├──linear (FooLinear, Linear: weight=11, bias=12) + ├──conv1d (FooConv1d) + ├──reg (nn.Linear, weight=13, bias=14) + + Parameters after initialization + model (FooModel) + ├──component1 (FooConv1d, weight=7, bias=8) + ├──component2 (FooConv2d, weight=9, bias=10) + ├──component3 (FooLinear, weight=1, bias=2) + ├──component4 (FooLinearConv1d) + ├──linear (FooLinear, weight=1, bias=2) + ├──conv1d (FooConv1d, weight=3, bias=4) + ├──reg (nn.Linear, weight=13, bias=14) + """ + + model_cfg = dict( + type='FooModel', + init_cfg=[ + dict( + type='Constant', + val=1, + bias=2, + layer='Linear', + override=dict(type='Constant', name='reg', val=13, 
bias=14)), + dict(type='Constant', val=3, bias=4, layer='Conv1d'), + dict(type='Constant', val=5, bias=6, layer='Conv2d'), + ], + component1=dict( + type='FooConv1d', init_cfg=dict(type='Constant', val=7, bias=8)), + component2=dict( + type='FooConv2d', init_cfg=dict(type='Constant', val=9, bias=10)), + component3=dict(type='FooLinear'), + component4=dict( + type='FooLinearConv1d', + linear=dict(type='FooLinear'), + conv1d=dict(type='FooConv1d'))) + + model = build_from_cfg(model_cfg, FOOMODELS) + model.init_weight() + + assert torch.equal(model.component1.conv1d.weight, + torch.full(model.component1.conv1d.weight.shape, 7.0)) + assert torch.equal(model.component1.conv1d.bias, + torch.full(model.component1.conv1d.bias.shape, 8.0)) + assert torch.equal(model.component2.conv2d.weight, + torch.full(model.component2.conv2d.weight.shape, 9.0)) + assert torch.equal(model.component2.conv2d.bias, + torch.full(model.component2.conv2d.bias.shape, 10.0)) + assert torch.equal(model.component3.linear.weight, + torch.full(model.component3.linear.weight.shape, 1.0)) + assert torch.equal(model.component3.linear.bias, + torch.full(model.component3.linear.bias.shape, 2.0)) + assert torch.equal( + model.component4.linear.linear.weight, + torch.full(model.component4.linear.linear.weight.shape, 1.0)) + assert torch.equal( + model.component4.linear.linear.bias, + torch.full(model.component4.linear.linear.bias.shape, 2.0)) + assert torch.equal( + model.component4.conv1d.conv1d.weight, + torch.full(model.component4.conv1d.conv1d.weight.shape, 3.0)) + assert torch.equal( + model.component4.conv1d.conv1d.bias, + torch.full(model.component4.conv1d.conv1d.bias.shape, 4.0)) + assert torch.equal(model.reg.weight, + torch.full(model.reg.weight.shape, 13.0)) + assert torch.equal(model.reg.bias, torch.full(model.reg.bias.shape, 14.0)) diff --git a/tests/test_runner/test_checkpoint.py b/tests/test_runner/test_checkpoint.py index c30238b7aa..5bffb8cd4b 100644 --- a/tests/test_runner/test_checkpoint.py +++ b/tests/test_runner/test_checkpoint.py @@ -1,5 +1,6 @@ import sys from collections import OrderedDict +from tempfile import TemporaryDirectory from unittest.mock import MagicMock import pytest @@ -8,7 +9,8 @@ from torch.nn.parallel import DataParallel from mmcv.parallel.registry import MODULE_WRAPPERS -from mmcv.runner.checkpoint import get_state_dict, load_from_pavi +from mmcv.runner.checkpoint import (_load_checkpoint_with_prefix, + get_state_dict, load_from_pavi) @MODULE_WRAPPERS.register_module() @@ -138,6 +140,7 @@ def test_get_state_dict(): def test_load_pavimodel_dist(): + sys.modules['pavi'] = MagicMock() sys.modules['pavi.modelcloud'] = MagicMock() pavimodel = Mockpavimodel() @@ -152,10 +155,45 @@ def test_load_pavimodel_dist(): _ = load_from_pavi('pavi://checkpoint.pth') +def test_load_checkpoint_with_prefix(): + + class FooModule(nn.Module): + + def __init__(self): + super().__init__() + self.linear = nn.Linear(1, 2) + self.conv2d = nn.Conv2d(3, 1, 3) + self.conv2d_2 = nn.Conv2d(3, 2, 3) + + model = FooModule() + nn.init.constant_(model.linear.weight, 1) + nn.init.constant_(model.linear.bias, 2) + nn.init.constant_(model.conv2d.weight, 3) + nn.init.constant_(model.conv2d.bias, 4) + nn.init.constant_(model.conv2d_2.weight, 5) + nn.init.constant_(model.conv2d_2.bias, 6) + + with TemporaryDirectory(): + torch.save(model.state_dict(), 'model.pth') + prefix = 'conv2d' + state_dict = _load_checkpoint_with_prefix(prefix, 'model.pth') + assert torch.equal(model.conv2d.state_dict()['weight'], + state_dict['weight']) + 
assert torch.equal(model.conv2d.state_dict()['bias'], + state_dict['bias']) + + # test whether prefix is in pretrained model + with pytest.raises(AssertionError): + prefix = 'back' + _load_checkpoint_with_prefix(prefix, 'model.pth') + + def test_load_classes_name(): - from mmcv.runner import load_checkpoint, save_checkpoint - import tempfile import os + + import tempfile + + from mmcv.runner import load_checkpoint, save_checkpoint checkpoint_path = os.path.join(tempfile.gettempdir(), 'checkpoint.pth') model = Model() save_checkpoint(model, checkpoint_path)
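A minimal usage sketch (outside the diff) of the initializer API introduced above; the module and attribute names (ToyNet, conv, head) are made up for illustration, and the expected behaviour mirrors the tests in this patch:

import torch.nn as nn

from mmcv.cnn import initialize
from mmcv.runner import BaseModule


class ToyNet(BaseModule):
    """Hypothetical module used only to illustrate the new init API."""

    def __init__(self, init_cfg=None):
        super().__init__(init_cfg)
        self.conv = nn.Conv2d(3, 8, 3)
        self.head = nn.Linear(8, 4)


# Functional style: Conv2d layers get weight=1/bias=2, Linear layers get
# weight=3/bias=4, and ``override`` re-initializes ``head`` with val=5/bias=6.
model = ToyNet()
initialize(model, [
    dict(type='Constant', layer='Conv2d', val=1, bias=2),
    dict(type='Constant', layer='Linear', val=3, bias=4,
         override=dict(type='Constant', name='head', val=5, bias=6)),
])

# Config style: the same dict can be passed as ``init_cfg`` and is applied
# lazily when ``init_weight()`` is called on the BaseModule.
model = ToyNet(init_cfg=dict(type='Constant', layer='Linear', val=1, bias=2))
model.init_weight()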