diff --git a/efficientnet_pytorch/model.py b/efficientnet_pytorch/model.py
index 3a42217..5021b15 100644
--- a/efficientnet_pytorch/model.py
+++ b/efficientnet_pytorch/model.py
@@ -150,7 +150,8 @@ def __init__(self, blocks_args=None, global_params=None):
         self._bn1 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps)
 
         # Final linear layer
-        self._dropout = self._global_params.dropout_rate
+        self._avg_pooling = nn.AdaptiveAvgPool2d(1)
+        self._dropout = nn.Dropout(self._global_params.dropout_rate)
         self._fc = nn.Linear(out_channels, self._global_params.num_classes)
 
     def extract_features(self, inputs):
@@ -173,14 +174,14 @@ def extract_features(self, inputs):
 
     def forward(self, inputs):
         """ Calls extract_features to extract features, applies final linear layer, and returns logits. """
-
+        bs = inputs.size(0)
         # Convolution layers
         x = self.extract_features(inputs)
 
         # Pooling and final linear layer
-        x = F.adaptive_avg_pool2d(x, 1).squeeze(-1).squeeze(-1)
-        if self._dropout:
-            x = F.dropout(x, p=self._dropout, training=self.training)
+        x = self._avg_pooling(x)
+        x = x.view(bs, -1)
+        x = self._dropout(x)
         x = self._fc(x)
         return x
 
diff --git a/tests/test_model.py b/tests/test_model.py
new file mode 100644
index 0000000..a138a2b
--- /dev/null
+++ b/tests/test_model.py
@@ -0,0 +1,107 @@
+from collections import OrderedDict
+
+import pytest
+import torch
+import torch.nn as nn
+
+from efficientnet_pytorch import EfficientNet
+
+
+# -- fixtures -------------------------------------------------------------------------------------
+
+@pytest.fixture(scope='module', params=[x for x in range(4)])
+def model(request):
+    return 'efficientnet-b{}'.format(request.param)
+
+
+@pytest.fixture(scope='module', params=[True, False])
+def pretrained(request):
+    return request.param
+
+
+@pytest.fixture(scope='function')
+def net(model, pretrained):
+    return EfficientNet.from_pretrained(model) if pretrained else EfficientNet.from_name(model)
+
+
+# -- tests ----------------------------------------------------------------------------------------
+
+@pytest.mark.parametrize('img_size', [224, 256, 512])
+def test_forward(net, img_size):
+    """Test `.forward()` doesn't throw an error"""
+    data = torch.zeros((1, 3, img_size, img_size))
+    output = net(data)
+    assert not torch.isnan(output).any()
+
+
+def test_dropout_training(net):
+    """Test dropout `.training` is set by `.train()` on parent `nn.Module`"""
+    net.train()
+    assert net._dropout.training == True
+
+
+def test_dropout_eval(net):
+    """Test dropout `.training` is set by `.eval()` on parent `nn.Module`"""
+    net.eval()
+    assert net._dropout.training == False
+
+
+def test_dropout_update(net):
+    """Test dropout `.training` is updated by `.train()` and `.eval()` on parent `nn.Module`"""
+    net.train()
+    assert net._dropout.training == True
+    net.eval()
+    assert net._dropout.training == False
+    net.train()
+    assert net._dropout.training == True
+    net.eval()
+    assert net._dropout.training == False
+
+
+@pytest.mark.parametrize('img_size', [224, 256, 512])
+def test_modify_dropout(net, img_size):
+    """Test ability to modify dropout and fc modules of network"""
+    dropout = nn.Sequential(OrderedDict([
+        ('_bn2', nn.BatchNorm1d(net._bn1.num_features)),
+        ('_drop1', nn.Dropout(p=net._global_params.dropout_rate)),
+        ('_linear1', nn.Linear(net._bn1.num_features, 512)),
+        ('_relu', nn.ReLU()),
+        ('_bn3', nn.BatchNorm1d(512)),
+        ('_drop2', nn.Dropout(p=net._global_params.dropout_rate / 2))
+    ]))
+    fc = nn.Linear(512, net._global_params.num_classes)
+
+    net._dropout = dropout
+    net._fc = fc
+
+    data = torch.zeros((2, 3, img_size, img_size))
+    output = net(data)
+    assert not torch.isnan(output).any()
+
+
+@pytest.mark.parametrize('img_size', [224, 256, 512])
+def test_modify_pool(net, img_size):
+    """Test ability to modify pooling module of network"""
+
+    class AdaptiveMaxAvgPool(nn.Module):
+
+        def __init__(self):
+            super().__init__()
+            self.ada_avgpool = nn.AdaptiveAvgPool2d(1)
+            self.ada_maxpool = nn.AdaptiveMaxPool2d(1)
+
+        def forward(self, x):
+            avg_x = self.ada_avgpool(x)
+            max_x = self.ada_maxpool(x)
+            x = torch.cat((avg_x, max_x), dim=1)
+            return x
+
+    avg_pooling = AdaptiveMaxAvgPool()
+    fc = nn.Linear(net._fc.in_features * 2, net._global_params.num_classes)
+
+    net._avg_pooling = avg_pooling
+    net._fc = fc
+
+    data = torch.zeros((2, 3, img_size, img_size))
+    output = net(data)
+    assert not torch.isnan(output).any()
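
A minimal usage sketch of the pattern these tests exercise (the model name and
class count below are illustrative, not part of the diff): because _avg_pooling,
_dropout, and _fc are now registered nn.Modules rather than a functional
F.dropout call, the head can be swapped attribute-by-attribute, and
.train()/.eval() propagate to the replacement automatically.

    import torch
    import torch.nn as nn

    from efficientnet_pytorch import EfficientNet

    net = EfficientNet.from_name('efficientnet-b0')

    # Swap the classifier head (10 classes here is an arbitrary example).
    net._dropout = nn.Dropout(p=0.3)
    net._fc = nn.Linear(net._fc.in_features, 10)

    # eval() now reaches the new dropout module through the usual
    # parent-module mechanism; the old F.dropout(..., training=self.training)
    # call could not be overridden from outside at all.
    net.eval()
    with torch.no_grad():
        logits = net(torch.zeros(1, 3, 224, 224))
    print(logits.shape)  # torch.Size([1, 10])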