Skip to content

Commit 7a4845a

Browse files
barrhfmassa
authored andcommitted
Add ShuffleNet v2 (#849)
* Add ShuffleNet v2 Added 4 configurations: x0.5, x1, x1.5, x2 Add 2 pretrained models: x0.5, x1 * fix lint * Change globalpool to torch.mean() call
1 parent e619613 commit 7a4845a

File tree

3 files changed

+191
-0
lines changed

3 files changed

+191
-0
lines changed

docs/source/models.rst

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ architectures:
1111
- `DenseNet`_
1212
- `Inception`_ v3
1313
- `GoogLeNet`_
14+
- `ShuffleNet`_ v2
1415

1516
You can construct a model with random weights by calling its constructor:
1617

@@ -24,6 +25,7 @@ You can construct a model with random weights by calling its constructor:
2425
densenet = models.densenet161()
2526
inception = models.inception_v3()
2627
googlenet = models.googlenet()
28+
shufflenet = models.shufflenetv2()
2729
2830
We provide pre-trained models, using the PyTorch :mod:`torch.utils.model_zoo`.
2931
These can be constructed by passing ``pretrained=True``:
@@ -38,6 +40,7 @@ These can be constructed by passing ``pretrained=True``:
3840
densenet = models.densenet161(pretrained=True)
3941
inception = models.inception_v3(pretrained=True)
4042
googlenet = models.googlenet(pretrained=True)
43+
shufflenet = models.shufflenetv2(pretrained=True)
4144
4245
Instancing a pre-trained model will download its weights to a cache directory.
4346
This directory can be set using the `TORCH_MODEL_ZOO` environment variable. See
@@ -88,6 +91,7 @@ Densenet-201 22.80 6.43
8891
Densenet-161 22.35 6.20
8992
Inception v3 22.55 6.44
9093
GoogleNet 30.22 10.47
94+
ShuffleNet V2 30.64 11.68
9195
================================ ============= =============
9296

9397

@@ -98,6 +102,7 @@ GoogleNet 30.22 10.47
98102
.. _DenseNet: https://arxiv.org/abs/1608.06993
99103
.. _Inception: https://arxiv.org/abs/1512.00567
100104
.. _GoogLeNet: https://arxiv.org/abs/1409.4842
105+
.. _ShuffleNet: https://arxiv.org/abs/1807.11164
101106

102107
.. currentmodule:: torchvision.models
103108

@@ -152,3 +157,8 @@ GoogLeNet
152157

153158
.. autofunction:: googlenet
154159

160+
ShuffleNet v2
161+
-------------
162+
163+
.. autofunction:: shufflenet
164+

torchvision/models/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,4 @@
66
from .densenet import *
77
from .googlenet import *
88
from .mobilenet import *
9+
from .shufflenetv2 import *

torchvision/models/shufflenetv2.py

Lines changed: 180 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,180 @@
1+
import functools
2+
3+
import torch
4+
import torch.nn as nn
5+
6+
__all__ = ['ShuffleNetV2', 'shufflenetv2',
7+
'shufflenetv2_x0_5', 'shufflenetv2_x1_0',
8+
'shufflenetv2_x1_5', 'shufflenetv2_x2_0']
9+
10+
model_urls = {
11+
'shufflenetv2_x0.5':
12+
'https://github.com/barrh/Shufflenet-v2-Pytorch/releases/download/v0.1.0/shufflenetv2_x0.5-f707e7126e.pt',
13+
'shufflenetv2_x1.0':
14+
'https://github.com/barrh/Shufflenet-v2-Pytorch/releases/download/v0.1.0/shufflenetv2_x1-5666bf0f80.pt',
15+
'shufflenetv2_x1.5': None,
16+
'shufflenetv2_x2.0': None,
17+
}
18+
19+
20+
def channel_shuffle(x, groups):
21+
batchsize, num_channels, height, width = x.data.size()
22+
channels_per_group = num_channels // groups
23+
24+
# reshape
25+
x = x.view(batchsize, groups,
26+
channels_per_group, height, width)
27+
28+
x = torch.transpose(x, 1, 2).contiguous()
29+
30+
# flatten
31+
x = x.view(batchsize, -1, height, width)
32+
33+
return x
34+
35+
36+
class InvertedResidual(nn.Module):
37+
def __init__(self, inp, oup, stride):
38+
super(InvertedResidual, self).__init__()
39+
40+
if not (1 <= stride <= 3):
41+
raise ValueError('illegal stride value')
42+
self.stride = stride
43+
44+
branch_features = oup // 2
45+
assert (self.stride != 1) or (inp == branch_features << 1)
46+
47+
pw_conv11 = functools.partial(nn.Conv2d, kernel_size=1, stride=1, padding=0, bias=False)
48+
dw_conv33 = functools.partial(self.depthwise_conv,
49+
kernel_size=3, stride=self.stride, padding=1)
50+
51+
if self.stride > 1:
52+
self.branch1 = nn.Sequential(
53+
dw_conv33(inp, inp),
54+
nn.BatchNorm2d(inp),
55+
pw_conv11(inp, branch_features),
56+
nn.BatchNorm2d(branch_features),
57+
nn.ReLU(inplace=True),
58+
)
59+
60+
self.branch2 = nn.Sequential(
61+
pw_conv11(inp if (self.stride > 1) else branch_features, branch_features),
62+
nn.BatchNorm2d(branch_features),
63+
nn.ReLU(inplace=True),
64+
dw_conv33(branch_features, branch_features),
65+
nn.BatchNorm2d(branch_features),
66+
pw_conv11(branch_features, branch_features),
67+
nn.BatchNorm2d(branch_features),
68+
nn.ReLU(inplace=True),
69+
)
70+
71+
@staticmethod
72+
def depthwise_conv(i, o, kernel_size, stride=1, padding=0, bias=False):
73+
return nn.Conv2d(i, o, kernel_size, stride, padding, bias=bias, groups=i)
74+
75+
def forward(self, x):
76+
if self.stride == 1:
77+
x1, x2 = x.chunk(2, dim=1)
78+
out = torch.cat((x1, self.branch2(x2)), dim=1)
79+
else:
80+
out = torch.cat((self.branch1(x), self.branch2(x)), dim=1)
81+
82+
out = channel_shuffle(out, 2)
83+
84+
return out
85+
86+
87+
class ShuffleNetV2(nn.Module):
88+
def __init__(self, num_classes=1000, input_size=224, width_mult=1):
89+
super(ShuffleNetV2, self).__init__()
90+
91+
try:
92+
self.stage_out_channels = self._getStages(float(width_mult))
93+
except KeyError:
94+
raise ValueError('width_mult {} is not supported'.format(width_mult))
95+
96+
input_channels = 3
97+
output_channels = self.stage_out_channels[0]
98+
self.conv1 = nn.Sequential(
99+
nn.Conv2d(input_channels, output_channels, 3, 2, 1, bias=False),
100+
nn.BatchNorm2d(output_channels),
101+
nn.ReLU(inplace=True),
102+
)
103+
input_channels = output_channels
104+
105+
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
106+
107+
stage_names = ['stage{}'.format(i) for i in [2, 3, 4]]
108+
stage_repeats = [4, 8, 4]
109+
for name, repeats, output_channels in zip(
110+
stage_names, stage_repeats, self.stage_out_channels[1:]):
111+
seq = [InvertedResidual(input_channels, output_channels, 2)]
112+
for i in range(repeats - 1):
113+
seq.append(InvertedResidual(output_channels, output_channels, 1))
114+
setattr(self, name, nn.Sequential(*seq))
115+
input_channels = output_channels
116+
117+
output_channels = self.stage_out_channels[-1]
118+
self.conv5 = nn.Sequential(
119+
nn.Conv2d(input_channels, output_channels, 1, 1, 0, bias=False),
120+
nn.BatchNorm2d(output_channels),
121+
nn.ReLU(inplace=True),
122+
)
123+
124+
self.fc = nn.Linear(output_channels, num_classes)
125+
126+
def forward(self, x):
127+
x = self.conv1(x)
128+
x = self.maxpool(x)
129+
x = self.stage2(x)
130+
x = self.stage3(x)
131+
x = self.stage4(x)
132+
x = self.conv5(x)
133+
x = x.mean([2, 3]) # globalpool
134+
x = self.fc(x)
135+
return x
136+
137+
@staticmethod
138+
def _getStages(mult):
139+
stages = {
140+
'0.5': [24, 48, 96, 192, 1024],
141+
'1.0': [24, 116, 232, 464, 1024],
142+
'1.5': [24, 176, 352, 704, 1024],
143+
'2.0': [24, 244, 488, 976, 2048],
144+
}
145+
return stages[str(mult)]
146+
147+
148+
def shufflenetv2(pretrained=False, num_classes=1000, input_size=224, width_mult=1, **kwargs):
149+
model = ShuffleNetV2(num_classes=num_classes, input_size=input_size, width_mult=width_mult)
150+
151+
if pretrained:
152+
# change width_mult to float
153+
if isinstance(width_mult, int):
154+
width_mult = float(width_mult)
155+
model_type = ('_'.join([ShuffleNetV2.__name__, 'x' + str(width_mult)]))
156+
try:
157+
model_url = model_urls[model_type.lower()]
158+
except KeyError:
159+
raise ValueError('model {} is not support'.format(model_type))
160+
if model_url is None:
161+
raise NotImplementedError('pretrained {} is not supported'.format(model_type))
162+
model.load_state_dict(torch.utils.model_zoo.load_url(model_url))
163+
164+
return model
165+
166+
167+
def shufflenetv2_x0_5(pretrained=False, num_classes=1000, input_size=224, **kwargs):
168+
return shufflenetv2(pretrained, num_classes, input_size, 0.5)
169+
170+
171+
def shufflenetv2_x1_0(pretrained=False, num_classes=1000, input_size=224, **kwargs):
172+
return shufflenetv2(pretrained, num_classes, input_size, 1)
173+
174+
175+
def shufflenetv2_x1_5(pretrained=False, num_classes=1000, input_size=224, **kwargs):
176+
return shufflenetv2(pretrained, num_classes, input_size, 1.5)
177+
178+
179+
def shufflenetv2_x2_0(pretrained=False, num_classes=1000, input_size=224, **kwargs):
180+
return shufflenetv2(pretrained, num_classes, input_size, 2)

0 commit comments

Comments
 (0)