Skip to content

Commit 2eb7a7d

Browse files
authored
Merge pull request #214 from chrisyeh96/constant_model_names
Move valid model names to constant
2 parents a27375b + d28f390 commit 2eb7a7d

File tree

2 files changed

+14
-8
lines changed

2 files changed

+14
-8
lines changed

efficientnet_pytorch/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
__version__ = "0.7.0"
2-
from .model import EfficientNet
2+
from .model import EfficientNet, VALID_MODELS
33
from .utils import (
44
GlobalParams,
55
BlockArgs,

efficientnet_pytorch/model.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,17 @@
2222
calculate_output_image_size
2323
)
2424

25+
26+
VALID_MODELS = (
27+
'efficientnet-b0', 'efficientnet-b1', 'efficientnet-b2', 'efficientnet-b3',
28+
'efficientnet-b4', 'efficientnet-b5', 'efficientnet-b6', 'efficientnet-b7',
29+
'efficientnet-b8',
30+
31+
# Support the construction of 'efficientnet-l2' without pretrained weights
32+
'efficientnet-l2'
33+
)
34+
35+
2536
class MBConvBlock(nn.Module):
2637
"""Mobile Inverted Residual Bottleneck Block.
2738
@@ -388,14 +399,9 @@ def _check_model_name_is_valid(cls, model_name):
388399
Returns:
389400
bool: Is a valid name or not.
390401
"""
391-
valid_models = ['efficientnet-b'+str(i) for i in range(9)]
392-
393-
# Support the construction of 'efficientnet-l2' without pretrained weights
394-
valid_models += ['efficientnet-l2']
402+
if model_name not in VALID_MODELS:
403+
raise ValueError('model_name should be one of: ' + ', '.join(VALID_MODELS))
395404

396-
if model_name not in valid_models:
397-
raise ValueError('model_name should be one of: ' + ', '.join(valid_models))
398-
399405
def _change_in_channels(self, in_channels):
400406
"""Adjust model's first convolution layer to in_channels, if in_channels not equals 3.
401407

0 commit comments

Comments
 (0)