|
3 | 3 | import torch |
4 | 4 | import torch.nn as nn |
5 | 5 | import torch.utils.model_zoo as model_zoo |
| 6 | +from .resnext_features import resnext50_32x4d_features |
6 | 7 | from .resnext_features import resnext101_32x4d_features |
7 | 8 | from .resnext_features import resnext101_64x4d_features |
8 | 9 |
|
9 | | -__all__ = ['ResNeXt101_32x4d', 'resnext101_32x4d', |
| 10 | +__all__ = ['ResNeXt50_32x4d', 'resnext50_32x4d', |
| 11 | + 'ResNeXt101_32x4d', 'resnext101_32x4d', |
10 | 12 | 'ResNeXt101_64x4d', 'resnext101_64x4d'] |
11 | 13 |
|
12 | 14 | pretrained_settings = { |
| 15 | + 'resnext50_32x4d': { |
| 16 | + 'imagenet': { |
| 17 | + 'url': 'file:/data/resnext50_32x4d-dc76b0bd094076dae.pth', # TODO (barrh): upload model, |
| 18 | + 'input_space': 'RGB', |
| 19 | + 'input_size': [3, 224, 224], |
| 20 | + 'input_range': [0, 1], |
| 21 | + 'mean': [0.485, 0.456, 0.406], |
| 22 | + 'std': [0.229, 0.224, 0.225], |
| 23 | + 'num_classes': 1000 |
| 24 | + } |
| 25 | + }, |
13 | 26 | 'resnext101_32x4d': { |
14 | 27 | 'imagenet': { |
15 | 28 | 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/resnext101_32x4d-29e315fa.pth', |
|
34 | 47 | } |
35 | 48 | } |
36 | 49 |
|
| 50 | +class ResNeXt50_32x4d(nn.Module): |
| 51 | + |
| 52 | + def __init__(self, num_classes=1000): |
| 53 | + super(ResNeXt50_32x4d, self).__init__() |
| 54 | + self.num_classes = num_classes |
| 55 | + self.features = resnext50_32x4d_features |
| 56 | + self.avg_pool = nn.AvgPool2d((7, 7), (1, 1)) |
| 57 | + self.last_linear = nn.Linear(2048, num_classes) |
| 58 | + |
| 59 | + def logits(self, input): |
| 60 | + x = self.avg_pool(input) |
| 61 | + x = x.view(x.size(0), -1) |
| 62 | + x = self.last_linear(x) |
| 63 | + return x |
| 64 | + |
| 65 | + def forward(self, input): |
| 66 | + x = self.features(input) |
| 67 | + x = self.logits(x) |
| 68 | + return x |
| 69 | + |
| 70 | + |
37 | 71 | class ResNeXt101_32x4d(nn.Module): |
38 | 72 |
|
39 | 73 | def __init__(self, num_classes=1000): |
@@ -76,6 +110,20 @@ def forward(self, input): |
76 | 110 | return x |
77 | 111 |
|
78 | 112 |
|
| 113 | +def resnext50_32x4d(num_classes=1000, pretrained='imagenet'): |
| 114 | + model = ResNeXt50_32x4d(num_classes=num_classes) |
| 115 | + if pretrained is not None: |
| 116 | + settings = pretrained_settings['resnext50_32x4d'][pretrained] |
| 117 | + assert num_classes == settings['num_classes'], \ |
| 118 | + "num_classes should be {}, but is {}".format(settings['num_classes'], num_classes) |
| 119 | + model.load_state_dict(model_zoo.load_url(settings['url'])) |
| 120 | + model.input_space = settings['input_space'] |
| 121 | + model.input_size = settings['input_size'] |
| 122 | + model.input_range = settings['input_range'] |
| 123 | + model.mean = settings['mean'] |
| 124 | + model.std = settings['std'] |
| 125 | + return model |
| 126 | + |
79 | 127 | def resnext101_32x4d(num_classes=1000, pretrained='imagenet'): |
80 | 128 | model = ResNeXt101_32x4d(num_classes=num_classes) |
81 | 129 | if pretrained is not None: |
|
0 commit comments