Commit c11240f

prabhat00155 authored and facebook-github-bot committed
[fbsync] Improved meta-data for models (#5170)
Summary:
* Improved meta-data for models.
* Addressing comments from code-review.
* Add parameter count.
* Fix linter.

Reviewed By: sallysyw
Differential Revision: D33479281
fbshipit-source-id: 7a133324ed5a289a0ac89522b0d4a38ce8b201e0
1 parent 1633844 commit c11240f

33 files changed: +289 -7 lines changed

test/test_prototype_models.py

Lines changed: 19 additions & 1 deletion
@@ -94,29 +94,47 @@ def test_naming_conventions(model_fn):
     + TM.get_models_from_module(models.video)
     + TM.get_models_from_module(models.optical_flow),
 )
+@run_if_test_with_prototype
 def test_schema_meta_validation(model_fn):
     classification_fields = ["size", "categories", "acc@1", "acc@5"]
     defaults = {
-        "all": ["interpolation", "recipe"],
+        "all": ["task", "architecture", "publication_year", "interpolation", "recipe", "num_params"],
         "models": classification_fields,
         "detection": ["categories", "map"],
         "quantization": classification_fields + ["backend", "quantization", "unquantized"],
         "segmentation": ["categories", "mIoU", "acc"],
         "video": classification_fields,
         "optical_flow": [],
     }
+    model_name = model_fn.__name__
     module_name = model_fn.__module__.split(".")[-2]
     fields = set(defaults["all"] + defaults[module_name])

     weights_enum = _get_model_weights(model_fn)
+    if len(weights_enum) == 0:
+        pytest.skip(f"Model '{model_name}' doesn't have any pre-trained weights.")

     problematic_weights = {}
+    incorrect_params = []
     for w in weights_enum:
         missing_fields = fields - set(w.meta.keys())
         if missing_fields:
             problematic_weights[w] = missing_fields
+        if w == weights_enum.default:
+            if module_name == "quantization":
+                # parameters() count doesn't work well with quantization, so we check against the non-quantized weights
+                unquantized_w = w.meta.get("unquantized")
+                if unquantized_w is not None and w.meta.get("num_params") != unquantized_w.meta.get("num_params"):
+                    incorrect_params.append(w)
+            else:
+                if w.meta.get("num_params") != sum(p.numel() for p in model_fn(weights=w).parameters()):
+                    incorrect_params.append(w)
+        else:
+            if w.meta.get("num_params") != weights_enum.default.meta.get("num_params"):
+                incorrect_params.append(w)

     assert not problematic_weights
+    assert not incorrect_params


 @pytest.mark.parametrize("model_fn", TM.get_models_from_module(models))
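
The new assertions verify that the num_params recorded in each default weight entry matches the actual parameter count of the instantiated model (quantized models are compared against their unquantized counterpart instead, since parameters() is unreliable there). The same check can be reproduced outside the test harness; a minimal sketch, assuming the prototype namespace from this commit is importable and using AlexNet purely as an illustration (the enum member name follows the prototype naming of this era):

from torchvision.prototype import models

# Illustrative weight entry; the enum member name is an assumption.
weights = models.AlexNet_Weights.ImageNet1K_V1
model = models.alexnet(weights=weights)

# Count every element of every parameter tensor, exactly as the test does.
actual = sum(p.numel() for p in model.parameters())
assert weights.meta["num_params"] == actual  # 61100840 for this entry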

torchvision/models/segmentation/fcn.py

Lines changed: 3 additions & 1 deletion
@@ -18,7 +18,9 @@

 class FCN(_SimpleSegmentationModel):
     """
-    Implements a Fully-Convolutional Network for semantic segmentation.
+    Implements FCN model from
+    `"Fully Convolutional Networks for Semantic Segmentation"
+    <https://arxiv.org/abs/1411.4038>`_.

     Args:
         backbone (nn.Module): the network used to compute the features for the model.

torchvision/prototype/models/alexnet.py

Lines changed: 4 additions & 0 deletions
@@ -18,6 +18,10 @@ class AlexNet_Weights(WeightsEnum):
         url="https://download.pytorch.org/models/alexnet-owt-7be5be79.pth",
         transforms=partial(ImageNetEval, crop_size=224),
         meta={
+            "task": "image_classification",
+            "architecture": "AlexNet",
+            "publication_year": 2012,
+            "num_params": 61100840,
             "size": (224, 224),
             "categories": _IMAGENET_CATEGORIES,
             "interpolation": InterpolationMode.BILINEAR,

torchvision/prototype/models/densenet.py

Lines changed: 7 additions & 0 deletions
@@ -64,6 +64,9 @@ def _densenet(


 _COMMON_META = {
+    "task": "image_classification",
+    "architecture": "DenseNet",
+    "publication_year": 2016,
     "size": (224, 224),
     "categories": _IMAGENET_CATEGORIES,
     "interpolation": InterpolationMode.BILINEAR,
@@ -77,6 +80,7 @@ class DenseNet121_Weights(WeightsEnum):
         transforms=partial(ImageNetEval, crop_size=224),
         meta={
             **_COMMON_META,
+            "num_params": 7978856,
             "acc@1": 74.434,
             "acc@5": 91.972,
         },
@@ -90,6 +94,7 @@ class DenseNet161_Weights(WeightsEnum):
         transforms=partial(ImageNetEval, crop_size=224),
         meta={
             **_COMMON_META,
+            "num_params": 28681000,
             "acc@1": 77.138,
             "acc@5": 93.560,
         },
@@ -103,6 +108,7 @@ class DenseNet169_Weights(WeightsEnum):
         transforms=partial(ImageNetEval, crop_size=224),
         meta={
             **_COMMON_META,
+            "num_params": 14149480,
             "acc@1": 75.600,
             "acc@5": 92.806,
         },
@@ -116,6 +122,7 @@ class DenseNet201_Weights(WeightsEnum):
         transforms=partial(ImageNetEval, crop_size=224),
         meta={
             **_COMMON_META,
+            "num_params": 20013928,
             "acc@1": 76.896,
             "acc@5": 93.370,
         },
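
The **_COMMON_META unpacking keeps family-wide fields in one place while each weight entry adds only what varies (here num_params and the accuracy numbers). This is plain dict merging; a self-contained sketch using values copied from the diff above (no torchvision import needed):

_COMMON_META = {
    "task": "image_classification",
    "architecture": "DenseNet",
    "publication_year": 2016,
}

# Shared fields are unpacked first, then per-weight fields are added;
# a duplicate key listed later would override the shared value.
densenet121_meta = {
    **_COMMON_META,
    "num_params": 7978856,
    "acc@1": 74.434,
    "acc@5": 91.972,
}

assert densenet121_meta["architecture"] == "DenseNet"
assert densenet121_meta["num_params"] == 7978856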

torchvision/prototype/models/detection/faster_rcnn.py

Lines changed: 6 additions & 0 deletions
@@ -31,6 +31,9 @@


 _COMMON_META = {
+    "task": "image_object_detection",
+    "architecture": "FasterRCNN",
+    "publication_year": 2015,
     "categories": _COCO_CATEGORIES,
     "interpolation": InterpolationMode.BILINEAR,
 }
@@ -42,6 +45,7 @@ class FasterRCNN_ResNet50_FPN_Weights(WeightsEnum):
         transforms=CocoEval,
         meta={
             **_COMMON_META,
+            "num_params": 41755286,
             "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-resnet-50-fpn",
             "map": 37.0,
         },
@@ -55,6 +59,7 @@ class FasterRCNN_MobileNet_V3_Large_FPN_Weights(WeightsEnum):
         transforms=CocoEval,
         meta={
             **_COMMON_META,
+            "num_params": 19386354,
             "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-mobilenetv3-large-fpn",
             "map": 32.8,
         },
@@ -68,6 +73,7 @@ class FasterRCNN_MobileNet_V3_Large_320_FPN_Weights(WeightsEnum):
         transforms=CocoEval,
         meta={
             **_COMMON_META,
+            "num_params": 19386354,
             "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-mobilenetv3-large-320-fpn",
             "map": 22.8,
         },

torchvision/prototype/models/detection/keypoint_rcnn.py

Lines changed: 5 additions & 0 deletions
@@ -24,6 +24,9 @@


 _COMMON_META = {
+    "task": "image_object_detection",
+    "architecture": "KeypointRCNN",
+    "publication_year": 2017,
     "categories": _COCO_PERSON_CATEGORIES,
     "keypoint_names": _COCO_PERSON_KEYPOINT_NAMES,
     "interpolation": InterpolationMode.BILINEAR,
@@ -36,6 +39,7 @@ class KeypointRCNN_ResNet50_FPN_Weights(WeightsEnum):
         transforms=CocoEval,
         meta={
             **_COMMON_META,
+            "num_params": 59137258,
             "recipe": "https://github.com/pytorch/vision/issues/1606",
             "map": 50.6,
             "map_kp": 61.1,
@@ -46,6 +50,7 @@ class KeypointRCNN_ResNet50_FPN_Weights(WeightsEnum):
         transforms=CocoEval,
         meta={
             **_COMMON_META,
+            "num_params": 59137258,
             "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#keypoint-r-cnn",
             "map": 54.6,
             "map_kp": 65.0,

torchvision/prototype/models/detection/mask_rcnn.py

Lines changed: 4 additions & 0 deletions
@@ -28,6 +28,10 @@ class MaskRCNN_ResNet50_FPN_Weights(WeightsEnum):
         url="https://download.pytorch.org/models/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth",
         transforms=CocoEval,
         meta={
+            "task": "image_object_detection",
+            "architecture": "MaskRCNN",
+            "publication_year": 2017,
+            "num_params": 44401393,
             "categories": _COCO_CATEGORIES,
             "interpolation": InterpolationMode.BILINEAR,
             "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#mask-r-cnn",

torchvision/prototype/models/detection/retinanet.py

Lines changed: 4 additions & 0 deletions
@@ -29,6 +29,10 @@ class RetinaNet_ResNet50_FPN_Weights(WeightsEnum):
         url="https://download.pytorch.org/models/retinanet_resnet50_fpn_coco-eeacb38b.pth",
         transforms=CocoEval,
         meta={
+            "task": "image_object_detection",
+            "architecture": "RetinaNet",
+            "publication_year": 2017,
+            "num_params": 34014999,
             "categories": _COCO_CATEGORIES,
             "interpolation": InterpolationMode.BILINEAR,
             "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#retinanet",

torchvision/prototype/models/detection/ssd.py

Lines changed: 4 additions & 0 deletions
@@ -27,6 +27,10 @@ class SSD300_VGG16_Weights(WeightsEnum):
         url="https://download.pytorch.org/models/ssd300_vgg16_coco-b556d3b4.pth",
         transforms=CocoEval,
         meta={
+            "task": "image_object_detection",
+            "architecture": "SSD",
+            "publication_year": 2015,
+            "num_params": 35641826,
             "size": (300, 300),
             "categories": _COCO_CATEGORIES,
             "interpolation": InterpolationMode.BILINEAR,

torchvision/prototype/models/detection/ssdlite.py

Lines changed: 4 additions & 0 deletions
@@ -32,6 +32,10 @@ class SSDLite320_MobileNet_V3_Large_Weights(WeightsEnum):
         url="https://download.pytorch.org/models/ssdlite320_mobilenet_v3_large_coco-a79551df.pth",
         transforms=CocoEval,
         meta={
+            "task": "image_object_detection",
+            "architecture": "SSDLite",
+            "publication_year": 2018,
+            "num_params": 3440060,
             "size": (320, 320),
             "categories": _COCO_CATEGORIES,
             "interpolation": InterpolationMode.BILINEAR,
