
Commit 72cecb1

Deprecating the mobilenetv3.SqueezeExcitation layer.

1 parent: 62dc50a

File tree: 2 files changed, +17 −11 lines

torchvision/models/mobilenetv3.py

Lines changed: 12 additions & 7 deletions
@@ -1,13 +1,13 @@
+import warnings
 import torch
 
 from functools import partial
 from torch import nn, Tensor
-from torch.nn import functional as F
-from typing import Any, Callable, Dict, List, Optional, Sequence
+from typing import Any, Callable, List, Optional, Sequence
 
 from .._internally_replaced_utils import load_state_dict_from_url
-from torchvision.models.efficientnet import SqueezeExcitation as SElayer
-from torchvision.models.mobilenetv2 import _make_divisible, ConvBNActivation
+from .efficientnet import SqueezeExcitation as SElayer
+from .mobilenetv2 import _make_divisible, ConvBNActivation
 
 
 __all__ = ["MobileNetV3", "mobilenet_v3_large", "mobilenet_v3_small"]
@@ -20,10 +20,14 @@
 
 
 class SqueezeExcitation(SElayer):
+    """DEPRECATED
+    """
     def __init__(self, input_channels: int, squeeze_factor: int = 4):
         squeeze_channels = _make_divisible(input_channels // squeeze_factor, 8)
-        super().__init__(input_channels, squeeze_channels, activation=nn.ReLU, scale_activation=nn.Hardsigmoid)
+        super().__init__(input_channels, squeeze_channels, scale_activation=nn.Hardsigmoid)
         self.relu = self.activation
+        warnings.warn(
+            "This SqueezeExcitation class is deprecated and will be removed in future versions.", FutureWarning)
 
 
 class InvertedResidualConfig:
@@ -47,7 +51,7 @@ def adjust_channels(channels: int, width_mult: float):
 class InvertedResidual(nn.Module):
     # Implemented as described at section 5 of MobileNetV3 paper
     def __init__(self, cnf: InvertedResidualConfig, norm_layer: Callable[..., nn.Module],
-                 se_layer: Callable[..., nn.Module] = SqueezeExcitation):
+                 se_layer: Callable[..., nn.Module] = partial(SElayer, scale_activation=nn.Hardsigmoid)):
         super().__init__()
         if not (1 <= cnf.stride <= 2):
             raise ValueError('illegal stride value')
@@ -68,7 +72,8 @@ def __init__(self, cnf: InvertedResidualConfig, norm_layer: Callable[..., nn.Mod
                                        stride=stride, dilation=cnf.dilation, groups=cnf.expanded_channels,
                                        norm_layer=norm_layer, activation_layer=activation_layer))
         if cnf.use_se:
-            layers.append(se_layer(cnf.expanded_channels))
+            squeeze_channels = _make_divisible(cnf.expanded_channels // 4, 8)
+            layers.append(se_layer(cnf.expanded_channels, squeeze_channels))
 
         # project
         layers.append(ConvBNActivation(cnf.expanded_channels, cnf.out_channels, kernel_size=1, norm_layer=norm_layer,
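
For downstream code, the deprecated class is now a thin wrapper over the shared efficientnet block. Below is a minimal migration sketch; the channel count is illustrative, and the import paths assume a torchvision build from around this commit:

    import warnings

    from torch import nn
    from torchvision.models.efficientnet import SqueezeExcitation as SElayer
    from torchvision.models.mobilenetv2 import _make_divisible
    from torchvision.models.mobilenetv3 import SqueezeExcitation

    input_channels = 72  # illustrative

    # Old path: still functional, but construction now emits a FutureWarning.
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        deprecated_se = SqueezeExcitation(input_channels)
    assert any(issubclass(w.category, FutureWarning) for w in caught)

    # Equivalent replacement, mirroring what InvertedResidual now builds internally:
    squeeze_channels = _make_divisible(input_channels // 4, 8)
    se = SElayer(input_channels, squeeze_channels, scale_activation=nn.Hardsigmoid)

The only call-site difference is that squeeze_channels is passed explicitly instead of being derived from a squeeze_factor; the Hardsigmoid scale activation is preserved through the new partial(SElayer, scale_activation=nn.Hardsigmoid) default.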

torchvision/models/quantization/mobilenetv3.py

Lines changed: 5 additions & 4 deletions
@@ -1,8 +1,9 @@
 import torch
 from torch import nn, Tensor
 from ..._internally_replaced_utils import load_state_dict_from_url
-from torchvision.models.mobilenetv3 import InvertedResidual, InvertedResidualConfig, ConvBNActivation, MobileNetV3,\
-    SqueezeExcitation, model_urls, _mobilenet_v3_conf
+from ..efficientnet import SqueezeExcitation as SElayer
+from ..mobilenetv3 import InvertedResidual, InvertedResidualConfig, ConvBNActivation, MobileNetV3,\
+    model_urls, _mobilenet_v3_conf
 from torch.quantization import QuantStub, DeQuantStub, fuse_modules
 from typing import Any, List, Optional
 from .utils import _replace_relu
@@ -16,7 +17,7 @@
 }
 
 
-class QuantizableSqueezeExcitation(SqueezeExcitation):
+class QuantizableSqueezeExcitation(SElayer):
     def __init__(self, *args: Any, **kwargs: Any) -> None:
         super().__init__(*args, **kwargs)
         self.skip_mul = nn.quantized.FloatFunctional()
@@ -25,7 +26,7 @@ def forward(self, input: Tensor) -> Tensor:
         return self.skip_mul.mul(self._scale(input), input)
 
     def fuse_model(self) -> None:
-        fuse_modules(self, ['fc1', 'relu'], inplace=True)
+        fuse_modules(self, ['fc1', 'activation'], inplace=True)
 
 
 class QuantizableInvertedResidual(InvertedResidual):
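
The fuse list changes because the new base class, SElayer, stores its first activation under the name activation and has no relu attribute at all (only the deprecated mobilenetv3 subclass keeps self.relu, as a backward-compatibility alias), so the old ['fc1', 'relu'] list would fail with an AttributeError. A short sketch of what fusion does to the block, assuming a PyTorch/torchvision build contemporary with this commit; the channel counts are illustrative:

    from torch import nn
    from torchvision.models.quantization.mobilenetv3 import QuantizableSqueezeExcitation

    # Eager-mode fusion expects eval mode.
    se = QuantizableSqueezeExcitation(72, 24, scale_activation=nn.Hardsigmoid).eval()
    se.fuse_model()

    # fc1 and its ReLU should now be folded into a single fused Conv+ReLU module,
    # with the standalone activation replaced by nn.Identity.
    print(type(se.fc1).__name__)         # expected: ConvReLU2d
    print(type(se.activation).__name__)  # expected: Identity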
