diff --git a/README.md b/README.md index 46dbf6a..4dfdd8b 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,9 @@ # EfficientNet PyTorch +### Update (October 08, 2019) + +This update changes the activation function implementation to a more memory-efficient one. For more details, please refer to: https://github.com/lukemelas/EfficientNet-PyTorch/issues/18 + ### Update (July 31, 2019) _Upgrade the pip package with_ `pip install --upgrade efficientnet-pytorch` diff --git a/efficientnet_pytorch/utils.py b/efficientnet_pytorch/utils.py index acdfb77..b8294c9 100644 --- a/efficientnet_pytorch/utils.py +++ b/efficientnet_pytorch/utils.py @@ -12,7 +12,6 @@ from torch.nn import functional as F from torch.utils import model_zoo - ######################################################################## ############### HELPERS FUNCTIONS FOR MODEL ARCHITECTURE ############### ######################################################################## @@ -24,21 +23,37 @@ 'num_classes', 'width_coefficient', 'depth_coefficient', 'depth_divisor', 'min_depth', 'drop_connect_rate', 'image_size']) - # Parameters for an individual model block BlockArgs = collections.namedtuple('BlockArgs', [ 'kernel_size', 'num_repeat', 'input_filters', 'output_filters', 'expand_ratio', 'id_skip', 'stride', 'se_ratio']) - # Change namedtuple defaults GlobalParams.__new__.__defaults__ = (None,) * len(GlobalParams._fields) BlockArgs.__new__.__defaults__ = (None,) * len(BlockArgs._fields) -def relu_fn(x): - """ Swish activation function """ - return x * torch.sigmoid(x) +class SwishImplementation(torch.autograd.Function): + @staticmethod + def forward(ctx, i): + result = i * torch.sigmoid(i) + ctx.save_for_backward(i) + return result + + @staticmethod + def backward(ctx, grad_output): + i = ctx.saved_variables[0] + sigmoid_i = torch.sigmoid(i) + return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i))) + + +class Swish(nn.Module): + @staticmethod + def forward(x): + return SwishImplementation.apply(x) + + +relu_fn = 
Swish() def round_filters(filters, global_params): @@ -84,11 +99,13 @@ def get_same_padding_conv2d(image_size=None): else: return partial(Conv2dStaticSamePadding, image_size=image_size) + class Conv2dDynamicSamePadding(nn.Conv2d): """ 2D Convolutions like TensorFlow, for a dynamic image size """ + def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, groups=1, bias=True): super().__init__(in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias) - self.stride = self.stride if len(self.stride) == 2 else [self.stride[0]]*2 + self.stride = self.stride if len(self.stride) == 2 else [self.stride[0]] * 2 def forward(self, x): ih, iw = x.size()[-2:] @@ -98,12 +115,13 @@ def forward(self, x): pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0) pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0) if pad_h > 0 or pad_w > 0: - x = F.pad(x, [pad_w//2, pad_w - pad_w//2, pad_h//2, pad_h - pad_h//2]) + x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2]) return F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups) class Conv2dStaticSamePadding(nn.Conv2d): """ 2D Convolutions like TensorFlow, for a fixed image size""" + def __init__(self, in_channels, out_channels, kernel_size, image_size=None, **kwargs): super().__init__(in_channels, out_channels, kernel_size, **kwargs) self.stride = self.stride if len(self.stride) == 2 else [self.stride[0]] * 2 @@ -128,7 +146,7 @@ def forward(self, x): class Identity(nn.Module): - def __init__(self,): + def __init__(self, ): super(Identity, self).__init__() def forward(self, input): @@ -286,6 +304,7 @@ def get_model_params(model_name, override_params): 'efficientnet-b7': 'http://storage.googleapis.com/public-models/efficientnet/efficientnet-b7-dcc49843.pth', } + def load_pretrained_weights(model, model_name, load_fc=True): """ Loads pretrained weights, and downloads if loading for the 
first time. """ state_dict = model_zoo.load_url(url_map[model_name])