1+ import warnings
12import torch
23
34from functools import partial
45from torch import nn , Tensor
5- from torch .nn import functional as F
6- from typing import Any , Callable , Dict , List , Optional , Sequence
6+ from typing import Any , Callable , List , Optional , Sequence
77
88from .._internally_replaced_utils import load_state_dict_from_url
9- from torchvision . models .efficientnet import SqueezeExcitation as SElayer
10- from torchvision . models .mobilenetv2 import _make_divisible , ConvBNActivation
9+ from .efficientnet import SqueezeExcitation as SElayer
10+ from .mobilenetv2 import _make_divisible , ConvBNActivation
1111
1212
1313__all__ = ["MobileNetV3" , "mobilenet_v3_large" , "mobilenet_v3_small" ]
2020
2121
class SqueezeExcitation(SElayer):
    """Deprecated squeeze-and-excitation block kept for backward compatibility.

    Thin wrapper over the shared ``efficientnet.SqueezeExcitation`` (imported
    here as ``SElayer``) that pins ``scale_activation`` to ``nn.Hardsigmoid``
    and derives the squeeze width as ``input_channels // squeeze_factor``
    rounded to a multiple of 8 via ``_make_divisible``. Emits a
    ``FutureWarning`` on construction; new code should use ``SElayer`` directly.
    """

    def __init__(self, input_channels: int, squeeze_factor: int = 4):
        # Squeeze width rounded to a hardware-friendly multiple of 8.
        super().__init__(
            input_channels,
            _make_divisible(input_channels // squeeze_factor, 8),
            scale_activation=nn.Hardsigmoid,
        )
        # Historical attribute name: the parent stores the same module as
        # ``self.activation``; keep ``relu`` as an alias for old checkpoints/code.
        self.relu = self.activation
        warnings.warn(
            "This SqueezeExcitation class is deprecated and will be removed in future versions.", FutureWarning)
2731
2832
2933class InvertedResidualConfig :
@@ -47,7 +51,7 @@ def adjust_channels(channels: int, width_mult: float):
4751class InvertedResidual (nn .Module ):
4852 # Implemented as described at section 5 of MobileNetV3 paper
4953 def __init__ (self , cnf : InvertedResidualConfig , norm_layer : Callable [..., nn .Module ],
50- se_layer : Callable [..., nn .Module ] = SqueezeExcitation ):
54+ se_layer : Callable [..., nn .Module ] = partial ( SElayer , scale_activation = nn . Hardsigmoid ) ):
5155 super ().__init__ ()
5256 if not (1 <= cnf .stride <= 2 ):
5357 raise ValueError ('illegal stride value' )
@@ -68,7 +72,8 @@ def __init__(self, cnf: InvertedResidualConfig, norm_layer: Callable[..., nn.Mod
6872 stride = stride , dilation = cnf .dilation , groups = cnf .expanded_channels ,
6973 norm_layer = norm_layer , activation_layer = activation_layer ))
7074 if cnf .use_se :
71- layers .append (se_layer (cnf .expanded_channels ))
75+ squeeze_channels = _make_divisible (cnf .expanded_channels // 4 , 8 )
76+ layers .append (se_layer (cnf .expanded_channels , squeeze_channels ))
7277
7378 # project
7479 layers .append (ConvBNActivation (cnf .expanded_channels , cnf .out_channels , kernel_size = 1 , norm_layer = norm_layer ,