Commit 532a6d7

Merge pull request #87 from lukemelas/relu_update
Updated ReLU, dropout, and more
2 parents de40cbf + 3be143e commit 532a6d7

File tree

6 files changed: +159 −17 lines changed

README.md

Lines changed: 4 additions & 0 deletions

@@ -1,5 +1,9 @@
 # EfficientNet PyTorch
 
+### Update (October 12, 2019)
+
+This update changes the activation function implementation to a more memory-efficient version. For more details, please refer to https://github.com/lukemelas/EfficientNet-PyTorch/issues/18. Thanks to [Dmytro Panchenko](https://www.kaggle.com/hokmund) for the pull request.
+
 ### Update (July 31, 2019)
 
 _Upgrade the pip package with_ `pip install --upgrade efficientnet-pytorch`
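For context on the memory claim (not part of the commit itself): the new implementation, shown in `efficientnet_pytorch/utils.py` below, registers Swish as a custom `torch.autograd.Function` that saves only the input tensor and recomputes the sigmoid during the backward pass, rather than letting autograd retain the intermediate `sigmoid(x)` tensor until backward. The hand-derived gradient it applies is:

```math
\frac{d}{dx}\,x\,\sigma(x) \;=\; \sigma(x) + x\,\sigma(x)\bigl(1-\sigma(x)\bigr) \;=\; \sigma(x)\bigl(1 + x\,(1-\sigma(x))\bigr)
```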

efficientnet_pytorch/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -1,4 +1,4 @@
-__version__ = "0.4.0"
+__version__ = "0.5.0"
 from .model import EfficientNet
 from .utils import (
     GlobalParams,

efficientnet_pytorch/model.py

Lines changed: 17 additions & 5 deletions

@@ -150,7 +150,8 @@ def __init__(self, blocks_args=None, global_params=None):
         self._bn1 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps)
 
         # Final linear layer
-        self._dropout = self._global_params.dropout_rate
+        self._avg_pooling = nn.AdaptiveAvgPool2d(1)
+        self._dropout = nn.Dropout(self._global_params.dropout_rate)
         self._fc = nn.Linear(out_channels, self._global_params.num_classes)
 
     def extract_features(self, inputs):
@@ -173,14 +174,14 @@ def extract_features(self, inputs):
 
     def forward(self, inputs):
         """ Calls extract_features to extract features, applies final linear layer, and returns logits. """
-
+        bs = inputs.size(0)
         # Convolution layers
         x = self.extract_features(inputs)
 
         # Pooling and final linear layer
-        x = F.adaptive_avg_pool2d(x, 1).squeeze(-1).squeeze(-1)
-        if self._dropout:
-            x = F.dropout(x, p=self._dropout, training=self.training)
+        x = self._avg_pooling(x)
+        x = x.view(bs, -1)
+        x = self._dropout(x)
         x = self._fc(x)
         return x
 
@@ -190,10 +191,21 @@ def from_name(cls, model_name, override_params=None):
         blocks_args, global_params = get_model_params(model_name, override_params)
         return cls(blocks_args, global_params)
 
+    @classmethod
+    def from_pretrained(cls, model_name, num_classes=1000, in_channels=3):
+        model = cls.from_name(model_name, override_params={'num_classes': num_classes})
+        load_pretrained_weights(model, model_name, load_fc=(num_classes == 1000))
+        if in_channels != 3:
+            Conv2d = get_same_padding_conv2d(image_size=model._global_params.image_size)
+            out_channels = round_filters(32, model._global_params)
+            model._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False)
+        return model
+
     @classmethod
     def from_pretrained(cls, model_name, num_classes=1000):
         model = cls.from_name(model_name, override_params={'num_classes': num_classes})
         load_pretrained_weights(model, model_name, load_fc=(num_classes == 1000))
+
         return model
 
     @classmethod
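Two notes on the hunk above. First, pooling and dropout are now `nn.Module` attributes (`_avg_pooling`, `_dropout`) rather than a functional call and a raw rate, so they respond to `.train()`/`.eval()` and can be replaced like any other layer. Second, the file now defines `from_pretrained` twice; since the later definition rebinds the name, the `in_channels` variant is shadowed by the two-argument version as merged. A minimal usage sketch under those semantics (illustrative, not from the commit; downloading pretrained weights requires network access):

```python
import torch
from torch import nn
from efficientnet_pytorch import EfficientNet

# Pretrained backbone with a 10-way head; load_fc=False is taken internally
# because num_classes != 1000, so only backbone weights are loaded.
model = EfficientNet.from_pretrained('efficientnet-b0', num_classes=10)

# _dropout is now a module, so it can be swapped out and it tracks train/eval.
model._dropout = nn.Dropout(p=0.5)
model.eval()
assert model._dropout.training is False

with torch.no_grad():
    logits = model(torch.zeros(1, 3, 224, 224))
print(logits.shape)  # expected: torch.Size([1, 10])
```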

efficientnet_pytorch/utils.py

Lines changed: 29 additions & 10 deletions

@@ -12,7 +12,6 @@
 from torch.nn import functional as F
 from torch.utils import model_zoo
 
-
 ########################################################################
 ############### HELPERS FUNCTIONS FOR MODEL ARCHITECTURE ###############
 ########################################################################
@@ -24,21 +23,37 @@
     'num_classes', 'width_coefficient', 'depth_coefficient',
     'depth_divisor', 'min_depth', 'drop_connect_rate', 'image_size'])
 
-
 # Parameters for an individual model block
 BlockArgs = collections.namedtuple('BlockArgs', [
     'kernel_size', 'num_repeat', 'input_filters', 'output_filters',
     'expand_ratio', 'id_skip', 'stride', 'se_ratio'])
 
-
 # Change namedtuple defaults
 GlobalParams.__new__.__defaults__ = (None,) * len(GlobalParams._fields)
 BlockArgs.__new__.__defaults__ = (None,) * len(BlockArgs._fields)
 
 
-def relu_fn(x):
-    """ Swish activation function """
-    return x * torch.sigmoid(x)
+class SwishImplementation(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, i):
+        result = i * torch.sigmoid(i)
+        ctx.save_for_backward(i)
+        return result
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        i = ctx.saved_variables[0]
+        sigmoid_i = torch.sigmoid(i)
+        return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i)))
+
+
+class Swish(nn.Module):
+    @staticmethod
+    def forward(x):
+        return SwishImplementation.apply(x)
+
+
+relu_fn = Swish()
 
 
 def round_filters(filters, global_params):
@@ -84,11 +99,13 @@ def get_same_padding_conv2d(image_size=None):
     else:
         return partial(Conv2dStaticSamePadding, image_size=image_size)
 
+
 class Conv2dDynamicSamePadding(nn.Conv2d):
     """ 2D Convolutions like TensorFlow, for a dynamic image size """
+
     def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, groups=1, bias=True):
         super().__init__(in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias)
-        self.stride = self.stride if len(self.stride) == 2 else [self.stride[0]]*2
+        self.stride = self.stride if len(self.stride) == 2 else [self.stride[0]] * 2
 
     def forward(self, x):
         ih, iw = x.size()[-2:]
@@ -98,12 +115,13 @@ def forward(self, x):
         pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0)
         pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0)
         if pad_h > 0 or pad_w > 0:
-            x = F.pad(x, [pad_w//2, pad_w - pad_w//2, pad_h//2, pad_h - pad_h//2])
+            x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2])
         return F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
 
 
 class Conv2dStaticSamePadding(nn.Conv2d):
     """ 2D Convolutions like TensorFlow, for a fixed image size """
+
     def __init__(self, in_channels, out_channels, kernel_size, image_size=None, **kwargs):
         super().__init__(in_channels, out_channels, kernel_size, **kwargs)
         self.stride = self.stride if len(self.stride) == 2 else [self.stride[0]] * 2
@@ -128,7 +146,7 @@ def forward(self, x):
 
 
 class Identity(nn.Module):
-    def __init__(self,):
+    def __init__(self):
         super(Identity, self).__init__()
 
     def forward(self, input):
@@ -286,6 +304,7 @@ def get_model_params(model_name, override_params):
     'efficientnet-b7': 'http://storage.googleapis.com/public-models/efficientnet/efficientnet-b7-dcc49843.pth',
 }
 
+
 def load_pretrained_weights(model, model_name, load_fc=True):
     """ Loads pretrained weights, and downloads if loading for the first time. """
     state_dict = model_zoo.load_url(url_map[model_name])
@@ -295,5 +314,5 @@ def load_pretrained_weights(model, model_name, load_fc=True):
         state_dict.pop('_fc.weight')
         state_dict.pop('_fc.bias')
     res = model.load_state_dict(state_dict, strict=False)
-    assert str(res.missing_keys) == str(['_fc.weight', '_fc.bias']), 'issue loading pretrained weights'
+    assert set(res.missing_keys) == set(['_fc.weight', '_fc.bias']), 'issue loading pretrained weights'
     print('Loaded pretrained weights for {}'.format(model_name))
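The correctness of the hand-written `backward` above can be checked numerically. A minimal, self-contained sketch (not from the commit; it mirrors `SwishImplementation` but uses `ctx.saved_tensors`, the non-deprecated spelling of `ctx.saved_variables`):

```python
import torch
from torch.autograd import Function, gradcheck

class MemoryEfficientSwish(Function):
    @staticmethod
    def forward(ctx, i):
        ctx.save_for_backward(i)   # keep only the input; sigmoid is recomputed in backward
        return i * torch.sigmoid(i)

    @staticmethod
    def backward(ctx, grad_output):
        i, = ctx.saved_tensors
        s = torch.sigmoid(i)
        # d/dx [x * sigmoid(x)] = sigmoid(x) * (1 + x * (1 - sigmoid(x)))
        return grad_output * (s * (1 + i * (1 - s)))

# gradcheck compares the analytical gradient against finite differences in float64.
x = torch.randn(8, dtype=torch.double, requires_grad=True)
assert gradcheck(MemoryEfficientSwish.apply, (x,))
```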

setup.py

Lines changed: 1 addition & 1 deletion

@@ -18,7 +18,7 @@
 AUTHOR = 'Luke'
 REQUIRES_PYTHON = '>=3.5.0'
-VERSION = '0.4.0'
+VERSION = '0.5.0'
 
 # What packages are required for this module to be executed?
 REQUIRED = [

tests/test_model.py

Lines changed: 107 additions & 0 deletions

@@ -0,0 +1,107 @@
+from collections import OrderedDict
+
+import pytest
+import torch
+import torch.nn as nn
+
+from efficientnet_pytorch import EfficientNet
+
+
+# -- fixtures -------------------------------------------------------------------------------------
+
+@pytest.fixture(scope='module', params=[x for x in range(4)])
+def model(request):
+    return 'efficientnet-b{}'.format(request.param)
+
+
+@pytest.fixture(scope='module', params=[True, False])
+def pretrained(request):
+    return request.param
+
+
+@pytest.fixture(scope='function')
+def net(model, pretrained):
+    return EfficientNet.from_pretrained(model) if pretrained else EfficientNet.from_name(model)
+
+
+# -- tests ----------------------------------------------------------------------------------------
+
+@pytest.mark.parametrize('img_size', [224, 256, 512])
+def test_forward(net, img_size):
+    """Test `.forward()` doesn't throw an error"""
+    data = torch.zeros((1, 3, img_size, img_size))
+    output = net(data)
+    assert not torch.isnan(output).any()
+
+
+def test_dropout_training(net):
+    """Test dropout `.training` is set by `.train()` on parent `nn.Module`"""
+    net.train()
+    assert net._dropout.training == True
+
+
+def test_dropout_eval(net):
+    """Test dropout `.training` is set by `.eval()` on parent `nn.Module`"""
+    net.eval()
+    assert net._dropout.training == False
+
+
+def test_dropout_update(net):
+    """Test dropout `.training` is updated by `.train()` and `.eval()` on parent `nn.Module`"""
+    net.train()
+    assert net._dropout.training == True
+    net.eval()
+    assert net._dropout.training == False
+    net.train()
+    assert net._dropout.training == True
+    net.eval()
+    assert net._dropout.training == False
+
+
+@pytest.mark.parametrize('img_size', [224, 256, 512])
+def test_modify_dropout(net, img_size):
+    """Test ability to modify dropout and fc modules of network"""
+    dropout = nn.Sequential(OrderedDict([
+        ('_bn2', nn.BatchNorm1d(net._bn1.num_features)),
+        ('_drop1', nn.Dropout(p=net._global_params.dropout_rate)),
+        ('_linear1', nn.Linear(net._bn1.num_features, 512)),
+        ('_relu', nn.ReLU()),
+        ('_bn3', nn.BatchNorm1d(512)),
+        ('_drop2', nn.Dropout(p=net._global_params.dropout_rate / 2))
+    ]))
+    fc = nn.Linear(512, net._global_params.num_classes)
+
+    net._dropout = dropout
+    net._fc = fc
+
+    data = torch.zeros((2, 3, img_size, img_size))
+    output = net(data)
+    assert not torch.isnan(output).any()
+
+
+@pytest.mark.parametrize('img_size', [224, 256, 512])
+def test_modify_pool(net, img_size):
+    """Test ability to modify pooling module of network"""
+
+    class AdaptiveMaxAvgPool(nn.Module):
+
+        def __init__(self):
+            super().__init__()
+            self.ada_avgpool = nn.AdaptiveAvgPool2d(1)
+            self.ada_maxpool = nn.AdaptiveMaxPool2d(1)
+
+        def forward(self, x):
+            avg_x = self.ada_avgpool(x)
+            max_x = self.ada_maxpool(x)
+            x = torch.cat((avg_x, max_x), dim=1)
+            return x
+
+    avg_pooling = AdaptiveMaxAvgPool()
+    fc = nn.Linear(net._fc.in_features * 2, net._global_params.num_classes)
+
+    net._avg_pooling = avg_pooling
+    net._fc = fc
+
+    data = torch.zeros((2, 3, img_size, img_size))
+    output = net(data)
+    assert not torch.isnan(output).any()
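The fixtures above parametrize over `efficientnet-b0` through `efficientnet-b3` and over pretrained vs. randomly initialized weights, so the suite exercises the new `_dropout` and `_avg_pooling` modules end to end. Assuming `pytest` is installed, it can be run locally with `pytest tests/test_model.py`; the pretrained fixtures download weights on first use.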
