Skip to content

Commit 8330fab

Browse files
committed
Adding B5-B7 weights.
1 parent d2bfd63 commit 8330fab

File tree

5 files changed

+48
-2
lines changed

5 files changed

+48
-2
lines changed

references/classification/train.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ def load_data(traindir, valdir, args):
9090
elif args.model.startswith('efficientnet_'):
9191
sizes = {
9292
'b0': (256, 224), 'b1': (256, 240), 'b2': (288, 288), 'b3': (320, 300),
93-
'b4': (384, 380), 'b5': (489, 456), 'b6': (561, 528), 'b7': (633, 600),
93+
'b4': (384, 380), 'b5': (456, 456), 'b6': (528, 528), 'b7': (600, 600),
9494
}
9595
e_type = args.model.replace('efficientnet_', '')
9696
resize_size, crop_size = sizes[e_type]
939 Bytes
Binary file not shown.
939 Bytes
Binary file not shown.
939 Bytes
Binary file not shown.

torchvision/models/efficientnet.py

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616

1717
__all__ = ["EfficientNet", "efficientnet_b0", "efficientnet_b1", "efficientnet_b2", "efficientnet_b3",
18-
"efficientnet_b4"]
18+
"efficientnet_b4", "efficientnet_b5", "efficientnet_b6", "efficientnet_b7"]
1919

2020

2121
model_urls = {
@@ -25,6 +25,10 @@
2525
"efficientnet_b2": "https://download.pytorch.org/models/efficientnet_b2-rwightman.pth",
2626
"efficientnet_b3": "https://download.pytorch.org/models/efficientnet_b3-rwightman.pth",
2727
"efficientnet_b4": "https://download.pytorch.org/models/efficientnet_b4-rwightman.pth",
28+
# Weights ported from https://github.com/lukemelas/EfficientNet-PyTorch/
29+
"efficientnet_b5": "https://download.pytorch.org/models/efficientnet_b5-lukemelas.pth",
30+
"efficientnet_b6": "https://download.pytorch.org/models/efficientnet_b6-lukemelas.pth",
31+
"efficientnet_b7": "https://download.pytorch.org/models/efficientnet_b7-lukemelas.pth",
2832
}
2933

3034

@@ -322,3 +326,45 @@ def efficientnet_b4(pretrained: bool = False, progress: bool = True, **kwargs: A
322326
"""
323327
inverted_residual_setting = _efficientnet_conf(width_mult=1.4, depth_mult=1.8, **kwargs)
324328
return _efficientnet_model("efficientnet_b4", inverted_residual_setting, 0.4, pretrained, progress, **kwargs)
329+
330+
331+
def efficientnet_b5(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> EfficientNet:
332+
"""
333+
Constructs a EfficientNet B5 architecture from
334+
`"EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks" <https://arxiv.org/abs/1905.11946>`_.
335+
336+
Args:
337+
pretrained (bool): If True, returns a model pre-trained on ImageNet
338+
progress (bool): If True, displays a progress bar of the download to stderr
339+
"""
340+
inverted_residual_setting = _efficientnet_conf(width_mult=1.6, depth_mult=2.2, **kwargs)
341+
return _efficientnet_model("efficientnet_b5", inverted_residual_setting, 0.4, pretrained, progress,
342+
norm_layer=partial(nn.BatchNorm2d, eps=0.001, momentum=0.01), **kwargs)
343+
344+
345+
def efficientnet_b6(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> EfficientNet:
346+
"""
347+
Constructs a EfficientNet B6 architecture from
348+
`"EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks" <https://arxiv.org/abs/1905.11946>`_.
349+
350+
Args:
351+
pretrained (bool): If True, returns a model pre-trained on ImageNet
352+
progress (bool): If True, displays a progress bar of the download to stderr
353+
"""
354+
inverted_residual_setting = _efficientnet_conf(width_mult=1.8, depth_mult=2.6, **kwargs)
355+
return _efficientnet_model("efficientnet_b6", inverted_residual_setting, 0.5, pretrained, progress,
356+
norm_layer=partial(nn.BatchNorm2d, eps=0.001, momentum=0.01), **kwargs)
357+
358+
359+
def efficientnet_b7(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> EfficientNet:
360+
"""
361+
Constructs a EfficientNet B7 architecture from
362+
`"EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks" <https://arxiv.org/abs/1905.11946>`_.
363+
364+
Args:
365+
pretrained (bool): If True, returns a model pre-trained on ImageNet
366+
progress (bool): If True, displays a progress bar of the download to stderr
367+
"""
368+
inverted_residual_setting = _efficientnet_conf(width_mult=2.0, depth_mult=3.1, **kwargs)
369+
return _efficientnet_model("efficientnet_b7", inverted_residual_setting, 0.5, pretrained, progress,
370+
norm_layer=partial(nn.BatchNorm2d, eps=0.001, momentum=0.01), **kwargs)

0 commit comments

Comments
 (0)