12
12
from torch .nn import functional as F
13
13
from torch .utils import model_zoo
14
14
15
-
16
15
########################################################################
17
16
############### HELPER FUNCTIONS FOR MODEL ARCHITECTURE ################
18
17
########################################################################
24
23
'num_classes' , 'width_coefficient' , 'depth_coefficient' ,
25
24
'depth_divisor' , 'min_depth' , 'drop_connect_rate' , 'image_size' ])
26
25
27
-
28
26
# Per-block configuration record: one instance describes a single model block.
BlockArgs = collections.namedtuple(
    'BlockArgs',
    [
        'kernel_size',
        'num_repeat',
        'input_filters',
        'output_filters',
        'expand_ratio',
        'id_skip',
        'stride',
        'se_ratio',
    ],
)
32
30
33
-
34
31
# Give every namedtuple field a default of None, so instances can be built
# from any subset of keyword arguments.
GlobalParams.__new__.__defaults__ = tuple(None for _ in GlobalParams._fields)
BlockArgs.__new__.__defaults__ = tuple(None for _ in BlockArgs._fields)
37
34
38
35
39
class SwishImplementation(torch.autograd.Function):
    """Swish activation, ``x * sigmoid(x)``, with a hand-written backward pass.

    Only the input tensor is saved for the backward pass; the sigmoid is
    recomputed there rather than kept from the forward pass.
    """

    @staticmethod
    def forward(ctx, i):
        """Return ``i * sigmoid(i)`` and stash ``i`` for the backward pass."""
        result = i * torch.sigmoid(i)
        ctx.save_for_backward(i)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        """Return the gradient of the loss with respect to the forward input.

        d/dx [x * s(x)] = s(x) * (1 + x * (1 - s(x))), where s is the sigmoid.
        """
        # `saved_tensors` is the supported accessor; the old
        # `saved_variables` alias is deprecated.
        (i,) = ctx.saved_tensors
        sigmoid_i = torch.sigmoid(i)
        return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i)))


class Swish(nn.Module):
    """Module wrapper that applies :class:`SwishImplementation` to its input."""

    # Kept static (as in the original) so both ``Swish()(x)`` and
    # ``Swish.forward(x)`` keep working.
    @staticmethod
    def forward(x):
        """Apply the Swish activation element-wise to ``x``."""
        return SwishImplementation.apply(x)


# Shared activation callable; the name is kept from the earlier plain-function
# implementation, so ``relu_fn(x)`` still works.
relu_fn = Swish()
42
57
43
58
44
59
def round_filters (filters , global_params ):
@@ -84,11 +99,13 @@ def get_same_padding_conv2d(image_size=None):
84
99
else :
85
100
return partial (Conv2dStaticSamePadding , image_size = image_size )
86
101
102
+
87
103
class Conv2dDynamicSamePadding (nn .Conv2d ):
88
104
""" 2D Convolutions like TensorFlow, for a dynamic image size """
105
+
89
106
def __init__ (self , in_channels , out_channels , kernel_size , stride = 1 , dilation = 1 , groups = 1 , bias = True ):
90
107
super ().__init__ (in_channels , out_channels , kernel_size , stride , 0 , dilation , groups , bias )
91
- self .stride = self .stride if len (self .stride ) == 2 else [self .stride [0 ]]* 2
108
+ self .stride = self .stride if len (self .stride ) == 2 else [self .stride [0 ]] * 2
92
109
93
110
def forward (self , x ):
94
111
ih , iw = x .size ()[- 2 :]
@@ -98,12 +115,13 @@ def forward(self, x):
98
115
pad_h = max ((oh - 1 ) * self .stride [0 ] + (kh - 1 ) * self .dilation [0 ] + 1 - ih , 0 )
99
116
pad_w = max ((ow - 1 ) * self .stride [1 ] + (kw - 1 ) * self .dilation [1 ] + 1 - iw , 0 )
100
117
if pad_h > 0 or pad_w > 0 :
101
- x = F .pad (x , [pad_w // 2 , pad_w - pad_w // 2 , pad_h // 2 , pad_h - pad_h // 2 ])
118
+ x = F .pad (x , [pad_w // 2 , pad_w - pad_w // 2 , pad_h // 2 , pad_h - pad_h // 2 ])
102
119
return F .conv2d (x , self .weight , self .bias , self .stride , self .padding , self .dilation , self .groups )
103
120
104
121
105
122
class Conv2dStaticSamePadding (nn .Conv2d ):
106
123
""" 2D Convolutions like TensorFlow, for a fixed image size"""
124
+
107
125
def __init__ (self , in_channels , out_channels , kernel_size , image_size = None , ** kwargs ):
108
126
super ().__init__ (in_channels , out_channels , kernel_size , ** kwargs )
109
127
self .stride = self .stride if len (self .stride ) == 2 else [self .stride [0 ]] * 2
@@ -128,7 +146,7 @@ def forward(self, x):
128
146
129
147
130
148
class Identity (nn .Module ):
131
- def __init__ (self ,):
149
+ def __init__ (self , ):
132
150
super (Identity , self ).__init__ ()
133
151
134
152
def forward (self , input ):
@@ -286,6 +304,7 @@ def get_model_params(model_name, override_params):
286
304
'efficientnet-b7' : 'http://storage.googleapis.com/public-models/efficientnet/efficientnet-b7-dcc49843.pth' ,
287
305
}
288
306
307
+
289
308
def load_pretrained_weights (model , model_name , load_fc = True ):
290
309
""" Loads pretrained weights, and downloads if loading for the first time. """
291
310
state_dict = model_zoo .load_url (url_map [model_name ])
@@ -295,5 +314,5 @@ def load_pretrained_weights(model, model_name, load_fc=True):
295
314
state_dict .pop ('_fc.weight' )
296
315
state_dict .pop ('_fc.bias' )
297
316
res = model .load_state_dict (state_dict , strict = False )
298
- assert str (res .missing_keys ) == str (['_fc.weight' , '_fc.bias' ]), 'issue loading pretrained weights'
317
+ assert set (res .missing_keys ) == set (['_fc.weight' , '_fc.bias' ]), 'issue loading pretrained weights'
299
318
print ('Loaded pretrained weights for {}' .format (model_name ))
0 commit comments