Commit 62dc50a (parent: 3e27eb2)

    Reuse EfficientNet SE layer.
2 files changed: 5 additions (+), 18 deletions (-)

torchvision/models/mobilenetv3.py (4 additions, 17 deletions)
@@ -6,6 +6,7 @@
 from typing import Any, Callable, Dict, List, Optional, Sequence
 
 from .._internally_replaced_utils import load_state_dict_from_url
+from torchvision.models.efficientnet import SqueezeExcitation as SElayer
 from torchvision.models.mobilenetv2 import _make_divisible, ConvBNActivation
 
 
@@ -18,25 +19,11 @@
 }
 
 
-class SqueezeExcitation(nn.Module):
-    # Implemented as described at Figure 4 of the MobileNetV3 paper
+class SqueezeExcitation(SElayer):
     def __init__(self, input_channels: int, squeeze_factor: int = 4):
-        super().__init__()
         squeeze_channels = _make_divisible(input_channels // squeeze_factor, 8)
-        self.fc1 = nn.Conv2d(input_channels, squeeze_channels, 1)
-        self.relu = nn.ReLU(inplace=True)
-        self.fc2 = nn.Conv2d(squeeze_channels, input_channels, 1)
-
-    def _scale(self, input: Tensor, inplace: bool) -> Tensor:
-        scale = F.adaptive_avg_pool2d(input, 1)
-        scale = self.fc1(scale)
-        scale = self.relu(scale)
-        scale = self.fc2(scale)
-        return F.hardsigmoid(scale, inplace=inplace)
-
-    def forward(self, input: Tensor) -> Tensor:
-        scale = self._scale(input, True)
-        return scale * input
+        super().__init__(input_channels, squeeze_channels, activation=nn.ReLU, scale_activation=nn.Hardsigmoid)
+        self.relu = self.activation
 
 
 class InvertedResidualConfig:
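For reference, a minimal sketch of the reused EfficientNet SqueezeExcitation base class, reconstructed from what the diff implies: the four-argument super().__init__ call, the fc1/fc2/activation attributes, and the parameterless _scale(input) seen in the quantized forward below. Attribute names not visible in the diff (e.g. avgpool) and the exact type hints are assumptions, not the authoritative source.

from typing import Callable

from torch import nn, Tensor


class SqueezeExcitation(nn.Module):
    def __init__(
        self,
        input_channels: int,
        squeeze_channels: int,
        activation: Callable[..., nn.Module] = nn.ReLU,
        scale_activation: Callable[..., nn.Module] = nn.Sigmoid,
    ) -> None:
        super().__init__()
        self.avgpool = nn.AdaptiveAvgPool2d(1)  # name is an assumption
        self.fc1 = nn.Conv2d(input_channels, squeeze_channels, 1)
        self.fc2 = nn.Conv2d(squeeze_channels, input_channels, 1)
        self.activation = activation()
        self.scale_activation = scale_activation()

    def _scale(self, input: Tensor) -> Tensor:
        # Squeeze: global average pool, then a two-layer channel bottleneck.
        scale = self.avgpool(input)
        scale = self.fc1(scale)
        scale = self.activation(scale)
        scale = self.fc2(scale)
        return self.scale_activation(scale)

    def forward(self, input: Tensor) -> Tensor:
        # Excite: rescale every channel by its attention weight.
        return self._scale(input) * input

Note the self.relu = self.activation alias in the subclass: it keeps the quantized model's fuse_modules(self, ['fc1', 'relu']) call (visible in the second file below) working after the refactor.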

torchvision/models/quantization/mobilenetv3.py (1 addition, 1 deletion)
@@ -22,7 +22,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
         self.skip_mul = nn.quantized.FloatFunctional()
 
     def forward(self, input: Tensor) -> Tensor:
-        return self.skip_mul.mul(self._scale(input, False), input)
+        return self.skip_mul.mul(self._scale(input), input)
 
     def fuse_model(self) -> None:
         fuse_modules(self, ['fc1', 'relu'], inplace=True)
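The inplace flag disappears here because the shared base implementation's _scale takes only the input. A quick sanity check of the refactored block, as a sketch assuming a torchvision build that includes this commit:

import torch
from torchvision.models.mobilenetv3 import SqueezeExcitation

se = SqueezeExcitation(input_channels=64)  # squeeze_factor defaults to 4
x = torch.randn(1, 64, 32, 32)
out = se(x)
assert out.shape == x.shape  # per-channel rescaling, shape unchanged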
