|  | 
|  | 1 | +import functools | 
|  | 2 | + | 
|  | 3 | +import torch | 
|  | 4 | +import torch.nn as nn | 
|  | 5 | + | 
|  | 6 | +__all__ = ['ShuffleNetV2', 'shufflenetv2', | 
|  | 7 | +           'shufflenetv2_x0_5', 'shufflenetv2_x1_0', | 
|  | 8 | +           'shufflenetv2_x1_5', 'shufflenetv2_x2_0'] | 
|  | 9 | + | 
|  | 10 | +model_urls = { | 
|  | 11 | +    'shufflenetv2_x0.5': | 
|  | 12 | +        'https://github.com/barrh/Shufflenet-v2-Pytorch/releases/download/v0.1.0/shufflenetv2_x0.5-f707e7126e.pt', | 
|  | 13 | +    'shufflenetv2_x1.0': | 
|  | 14 | +        'https://github.com/barrh/Shufflenet-v2-Pytorch/releases/download/v0.1.0/shufflenetv2_x1-5666bf0f80.pt', | 
|  | 15 | +    'shufflenetv2_x1.5': None, | 
|  | 16 | +    'shufflenetv2_x2.0': None, | 
|  | 17 | +} | 
|  | 18 | + | 
|  | 19 | + | 
|  | 20 | +def channel_shuffle(x, groups): | 
|  | 21 | +    batchsize, num_channels, height, width = x.data.size() | 
|  | 22 | +    channels_per_group = num_channels // groups | 
|  | 23 | + | 
|  | 24 | +    # reshape | 
|  | 25 | +    x = x.view(batchsize, groups, | 
|  | 26 | +               channels_per_group, height, width) | 
|  | 27 | + | 
|  | 28 | +    x = torch.transpose(x, 1, 2).contiguous() | 
|  | 29 | + | 
|  | 30 | +    # flatten | 
|  | 31 | +    x = x.view(batchsize, -1, height, width) | 
|  | 32 | + | 
|  | 33 | +    return x | 
|  | 34 | + | 
|  | 35 | + | 
|  | 36 | +class InvertedResidual(nn.Module): | 
|  | 37 | +    def __init__(self, inp, oup, stride): | 
|  | 38 | +        super(InvertedResidual, self).__init__() | 
|  | 39 | + | 
|  | 40 | +        if not (1 <= stride <= 3): | 
|  | 41 | +            raise ValueError('illegal stride value') | 
|  | 42 | +        self.stride = stride | 
|  | 43 | + | 
|  | 44 | +        branch_features = oup // 2 | 
|  | 45 | +        assert (self.stride != 1) or (inp == branch_features << 1) | 
|  | 46 | + | 
|  | 47 | +        pw_conv11 = functools.partial(nn.Conv2d, kernel_size=1, stride=1, padding=0, bias=False) | 
|  | 48 | +        dw_conv33 = functools.partial(self.depthwise_conv, | 
|  | 49 | +                                      kernel_size=3, stride=self.stride, padding=1) | 
|  | 50 | + | 
|  | 51 | +        if self.stride > 1: | 
|  | 52 | +            self.branch1 = nn.Sequential( | 
|  | 53 | +                dw_conv33(inp, inp), | 
|  | 54 | +                nn.BatchNorm2d(inp), | 
|  | 55 | +                pw_conv11(inp, branch_features), | 
|  | 56 | +                nn.BatchNorm2d(branch_features), | 
|  | 57 | +                nn.ReLU(inplace=True), | 
|  | 58 | +            ) | 
|  | 59 | + | 
|  | 60 | +        self.branch2 = nn.Sequential( | 
|  | 61 | +            pw_conv11(inp if (self.stride > 1) else branch_features, branch_features), | 
|  | 62 | +            nn.BatchNorm2d(branch_features), | 
|  | 63 | +            nn.ReLU(inplace=True), | 
|  | 64 | +            dw_conv33(branch_features, branch_features), | 
|  | 65 | +            nn.BatchNorm2d(branch_features), | 
|  | 66 | +            pw_conv11(branch_features, branch_features), | 
|  | 67 | +            nn.BatchNorm2d(branch_features), | 
|  | 68 | +            nn.ReLU(inplace=True), | 
|  | 69 | +        ) | 
|  | 70 | + | 
|  | 71 | +    @staticmethod | 
|  | 72 | +    def depthwise_conv(i, o, kernel_size, stride=1, padding=0, bias=False): | 
|  | 73 | +        return nn.Conv2d(i, o, kernel_size, stride, padding, bias=bias, groups=i) | 
|  | 74 | + | 
|  | 75 | +    def forward(self, x): | 
|  | 76 | +        if self.stride == 1: | 
|  | 77 | +            x1, x2 = x.chunk(2, dim=1) | 
|  | 78 | +            out = torch.cat((x1, self.branch2(x2)), dim=1) | 
|  | 79 | +        else: | 
|  | 80 | +            out = torch.cat((self.branch1(x), self.branch2(x)), dim=1) | 
|  | 81 | + | 
|  | 82 | +        out = channel_shuffle(out, 2) | 
|  | 83 | + | 
|  | 84 | +        return out | 
|  | 85 | + | 
|  | 86 | + | 
|  | 87 | +class ShuffleNetV2(nn.Module): | 
|  | 88 | +    def __init__(self, num_classes=1000, input_size=224, width_mult=1): | 
|  | 89 | +        super(ShuffleNetV2, self).__init__() | 
|  | 90 | + | 
|  | 91 | +        try: | 
|  | 92 | +            self.stage_out_channels = self._getStages(float(width_mult)) | 
|  | 93 | +        except KeyError: | 
|  | 94 | +            raise ValueError('width_mult {} is not supported'.format(width_mult)) | 
|  | 95 | + | 
|  | 96 | +        input_channels = 3 | 
|  | 97 | +        output_channels = self.stage_out_channels[0] | 
|  | 98 | +        self.conv1 = nn.Sequential( | 
|  | 99 | +            nn.Conv2d(input_channels, output_channels, 3, 2, 1, bias=False), | 
|  | 100 | +            nn.BatchNorm2d(output_channels), | 
|  | 101 | +            nn.ReLU(inplace=True), | 
|  | 102 | +        ) | 
|  | 103 | +        input_channels = output_channels | 
|  | 104 | + | 
|  | 105 | +        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) | 
|  | 106 | + | 
|  | 107 | +        stage_names = ['stage{}'.format(i) for i in [2, 3, 4]] | 
|  | 108 | +        stage_repeats = [4, 8, 4] | 
|  | 109 | +        for name, repeats, output_channels in zip( | 
|  | 110 | +                stage_names, stage_repeats, self.stage_out_channels[1:]): | 
|  | 111 | +            seq = [InvertedResidual(input_channels, output_channels, 2)] | 
|  | 112 | +            for i in range(repeats - 1): | 
|  | 113 | +                seq.append(InvertedResidual(output_channels, output_channels, 1)) | 
|  | 114 | +            setattr(self, name, nn.Sequential(*seq)) | 
|  | 115 | +            input_channels = output_channels | 
|  | 116 | + | 
|  | 117 | +        output_channels = self.stage_out_channels[-1] | 
|  | 118 | +        self.conv5 = nn.Sequential( | 
|  | 119 | +            nn.Conv2d(input_channels, output_channels, 1, 1, 0, bias=False), | 
|  | 120 | +            nn.BatchNorm2d(output_channels), | 
|  | 121 | +            nn.ReLU(inplace=True), | 
|  | 122 | +        ) | 
|  | 123 | + | 
|  | 124 | +        self.fc = nn.Linear(output_channels, num_classes) | 
|  | 125 | + | 
|  | 126 | +    def forward(self, x): | 
|  | 127 | +        x = self.conv1(x) | 
|  | 128 | +        x = self.maxpool(x) | 
|  | 129 | +        x = self.stage2(x) | 
|  | 130 | +        x = self.stage3(x) | 
|  | 131 | +        x = self.stage4(x) | 
|  | 132 | +        x = self.conv5(x) | 
|  | 133 | +        x = x.mean([2, 3])  # globalpool | 
|  | 134 | +        x = self.fc(x) | 
|  | 135 | +        return x | 
|  | 136 | + | 
|  | 137 | +    @staticmethod | 
|  | 138 | +    def _getStages(mult): | 
|  | 139 | +        stages = { | 
|  | 140 | +            '0.5': [24, 48, 96, 192, 1024], | 
|  | 141 | +            '1.0': [24, 116, 232, 464, 1024], | 
|  | 142 | +            '1.5': [24, 176, 352, 704, 1024], | 
|  | 143 | +            '2.0': [24, 244, 488, 976, 2048], | 
|  | 144 | +        } | 
|  | 145 | +        return stages[str(mult)] | 
|  | 146 | + | 
|  | 147 | + | 
|  | 148 | +def shufflenetv2(pretrained=False, num_classes=1000, input_size=224, width_mult=1, **kwargs): | 
|  | 149 | +    model = ShuffleNetV2(num_classes=num_classes, input_size=input_size, width_mult=width_mult) | 
|  | 150 | + | 
|  | 151 | +    if pretrained: | 
|  | 152 | +        # change width_mult to float | 
|  | 153 | +        if isinstance(width_mult, int): | 
|  | 154 | +            width_mult = float(width_mult) | 
|  | 155 | +        model_type = ('_'.join([ShuffleNetV2.__name__, 'x' + str(width_mult)])) | 
|  | 156 | +        try: | 
|  | 157 | +            model_url = model_urls[model_type.lower()] | 
|  | 158 | +        except KeyError: | 
|  | 159 | +            raise ValueError('model {} is not support'.format(model_type)) | 
|  | 160 | +        if model_url is None: | 
|  | 161 | +            raise NotImplementedError('pretrained {} is not supported'.format(model_type)) | 
|  | 162 | +        model.load_state_dict(torch.utils.model_zoo.load_url(model_url)) | 
|  | 163 | + | 
|  | 164 | +    return model | 
|  | 165 | + | 
|  | 166 | + | 
|  | 167 | +def shufflenetv2_x0_5(pretrained=False, num_classes=1000, input_size=224, **kwargs): | 
|  | 168 | +    return shufflenetv2(pretrained, num_classes, input_size, 0.5) | 
|  | 169 | + | 
|  | 170 | + | 
|  | 171 | +def shufflenetv2_x1_0(pretrained=False, num_classes=1000, input_size=224, **kwargs): | 
|  | 172 | +    return shufflenetv2(pretrained, num_classes, input_size, 1) | 
|  | 173 | + | 
|  | 174 | + | 
|  | 175 | +def shufflenetv2_x1_5(pretrained=False, num_classes=1000, input_size=224, **kwargs): | 
|  | 176 | +    return shufflenetv2(pretrained, num_classes, input_size, 1.5) | 
|  | 177 | + | 
|  | 178 | + | 
|  | 179 | +def shufflenetv2_x2_0(pretrained=False, num_classes=1000, input_size=224, **kwargs): | 
|  | 180 | +    return shufflenetv2(pretrained, num_classes, input_size, 2) | 
0 commit comments