6 | 6 | from typing import Any, Callable, Dict, List, Optional, Sequence |
7 | 7 |
8 | 8 | from .._internally_replaced_utils import load_state_dict_from_url |
| 9 | +from torchvision.models.efficientnet import SqueezeExcitation as SElayer |
9 | 10 | from torchvision.models.mobilenetv2 import _make_divisible, ConvBNActivation |
10 | 11 |
11 | 12 |
18 | 19 | } |
19 | 20 |
20 | 21 |
21 | | -class SqueezeExcitation(nn.Module): |
22 | | -    # Implemented as described at Figure 4 of the MobileNetV3 paper |
| 22 | +class SqueezeExcitation(SElayer): |
23 | 23 |     def __init__(self, input_channels: int, squeeze_factor: int = 4): |
24 | | -        super().__init__() |
25 | 24 |         squeeze_channels = _make_divisible(input_channels // squeeze_factor, 8) |
26 | | -        self.fc1 = nn.Conv2d(input_channels, squeeze_channels, 1) |
27 | | -        self.relu = nn.ReLU(inplace=True) |
28 | | -        self.fc2 = nn.Conv2d(squeeze_channels, input_channels, 1) |
29 | | - |
30 | | -    def _scale(self, input: Tensor, inplace: bool) -> Tensor: |
31 | | -        scale = F.adaptive_avg_pool2d(input, 1) |
32 | | -        scale = self.fc1(scale) |
33 | | -        scale = self.relu(scale) |
34 | | -        scale = self.fc2(scale) |
35 | | -        return F.hardsigmoid(scale, inplace=inplace) |
36 | | - |
37 | | -    def forward(self, input: Tensor) -> Tensor: |
38 | | -        scale = self._scale(input, True) |
39 | | -        return scale * input |
| 25 | +        super().__init__(input_channels, squeeze_channels, activation=nn.ReLU, scale_activation=nn.Hardsigmoid) |
| 26 | +        self.relu = self.activation |
40 | 27 |
41 | 28 |
42 | 29 | class InvertedResidualConfig: |
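For context, here is a minimal smoke test of the refactored block. It is a hypothetical snippet, not part of this PR, and assumes a torchvision build containing both this mobilenetv3.py and the efficientnet SqueezeExcitation it now subclasses. The subclass keeps the old forward behavior, scaling the input by hardsigmoid(fc2(relu(fc1(avgpool(x))))), and the self.relu = self.activation line keeps the old attribute name resolvable for downstream code:

import torch
from torch import nn
from torchvision.models.mobilenetv3 import SqueezeExcitation

se = SqueezeExcitation(input_channels=16)  # squeeze_factor defaults to 4; _make_divisible(16 // 4, 8) gives 8 squeeze channels
x = torch.randn(2, 16, 32, 32)
out = se(x)

assert out.shape == x.shape          # SE only re-weights channels, so the shape is unchanged
assert isinstance(se.relu, nn.ReLU)  # the legacy relu attribute still resolves via the alias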