20 changes: 19 additions & 1 deletion test/test_prototype_models.py
@@ -94,29 +94,47 @@ def test_naming_conventions(model_fn):
    + TM.get_models_from_module(models.video)
    + TM.get_models_from_module(models.optical_flow),
)
@run_if_test_with_prototype
def test_schema_meta_validation(model_fn):
    classification_fields = ["size", "categories", "acc@1", "acc@5"]
    defaults = {
-        "all": ["interpolation", "recipe"],
+        "all": ["task", "architecture", "publication_year", "interpolation", "recipe", "num_params"],
        "models": classification_fields,
        "detection": ["categories", "map"],
        "quantization": classification_fields + ["backend", "quantization", "unquantized"],
        "segmentation": ["categories", "mIoU", "acc"],
        "video": classification_fields,
        "optical_flow": [],
    }
    model_name = model_fn.__name__
    module_name = model_fn.__module__.split(".")[-2]
    fields = set(defaults["all"] + defaults[module_name])

    weights_enum = _get_model_weights(model_fn)
    if len(weights_enum) == 0:
        pytest.skip(f"Model '{model_name}' doesn't have any pre-trained weights.")

    problematic_weights = {}
    incorrect_params = []
    for w in weights_enum:
        missing_fields = fields - set(w.meta.keys())
        if missing_fields:
            problematic_weights[w] = missing_fields
        if w == weights_enum.default:
            if module_name == "quantization":
                # parameters() count doesn't work well with quantization, so we check against the non-quantized weights
                unquantized_w = w.meta.get("unquantized")
                if unquantized_w is not None and w.meta.get("num_params") != unquantized_w.meta.get("num_params"):
                    incorrect_params.append(w)
            else:
                if w.meta.get("num_params") != sum(p.numel() for p in model_fn(weights=w).parameters()):
                    incorrect_params.append(w)
        else:
            if w.meta.get("num_params") != weights_enum.default.meta.get("num_params"):
                incorrect_params.append(w)

    assert not problematic_weights
    assert not incorrect_params


@pytest.mark.parametrize("model_fn", TM.get_models_from_module(models))
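Side note, not part of the diff: the num_params values recorded in the meta dicts throughout this PR can be reproduced with an ordinary parameter count, the same way the test above does it. A minimal sketch, assuming only the public torchvision model builders:

# Minimal sketch (not part of this PR): reproduce a "num_params" meta value
# by building a model and counting its parameters, as the test above does.
import torchvision.models as models

def count_params(model_fn, **kwargs):
    model = model_fn(**kwargs)  # random init is fine; only the parameter count matters
    return sum(p.numel() for p in model.parameters())

print(count_params(models.alexnet))  # 61100840, matching AlexNet_Weights below

Quantized models are the exception: parameter counting doesn't play well with quantization, so the test instead compares their num_params against the meta of the corresponding unquantized weights.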
4 changes: 3 additions & 1 deletion torchvision/models/segmentation/fcn.py
@@ -18,7 +18,9 @@

class FCN(_SimpleSegmentationModel):
    """
-    Implements a Fully-Convolutional Network for semantic segmentation.
+    Implements FCN model from
+    `"Fully Convolutional Networks for Semantic Segmentation"
+    <https://arxiv.org/abs/1411.4038>`_.

    Args:
        backbone (nn.Module): the network used to compute the features for the model.
4 changes: 4 additions & 0 deletions torchvision/prototype/models/alexnet.py
@@ -18,6 +18,10 @@ class AlexNet_Weights(WeightsEnum):
        url="https://download.pytorch.org/models/alexnet-owt-7be5be79.pth",
        transforms=partial(ImageNetEval, crop_size=224),
        meta={
            "task": "image_classification",
            "architecture": "AlexNet",
            "publication_year": 2012,
            "num_params": 61100840,
            "size": (224, 224),
            "categories": _IMAGENET_CATEGORIES,
            "interpolation": InterpolationMode.BILINEAR,
7 changes: 7 additions & 0 deletions torchvision/prototype/models/densenet.py
@@ -64,6 +64,9 @@ def _densenet(


_COMMON_META = {
    "task": "image_classification",
    "architecture": "DenseNet",
    "publication_year": 2016,
    "size": (224, 224),
    "categories": _IMAGENET_CATEGORIES,
    "interpolation": InterpolationMode.BILINEAR,
@@ -77,6 +80,7 @@ class DenseNet121_Weights(WeightsEnum):
        transforms=partial(ImageNetEval, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 7978856,
            "acc@1": 74.434,
            "acc@5": 91.972,
        },
@@ -90,6 +94,7 @@ class DenseNet161_Weights(WeightsEnum):
        transforms=partial(ImageNetEval, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 28681000,
            "acc@1": 77.138,
            "acc@5": 93.560,
        },
@@ -103,6 +108,7 @@ class DenseNet169_Weights(WeightsEnum):
        transforms=partial(ImageNetEval, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 14149480,
            "acc@1": 75.600,
            "acc@5": 92.806,
        },
@@ -116,6 +122,7 @@ class DenseNet201_Weights(WeightsEnum):
        transforms=partial(ImageNetEval, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 20013928,
            "acc@1": 76.896,
            "acc@5": 93.370,
        },
6 changes: 6 additions & 0 deletions torchvision/prototype/models/detection/faster_rcnn.py
@@ -31,6 +31,9 @@


_COMMON_META = {
    "task": "image_object_detection",
    "architecture": "FasterRCNN",
    "publication_year": 2015,
    "categories": _COCO_CATEGORIES,
    "interpolation": InterpolationMode.BILINEAR,
}
@@ -42,6 +45,7 @@ class FasterRCNN_ResNet50_FPN_Weights(WeightsEnum):
        transforms=CocoEval,
        meta={
            **_COMMON_META,
            "num_params": 41755286,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-resnet-50-fpn",
            "map": 37.0,
        },
@@ -55,6 +59,7 @@ class FasterRCNN_MobileNet_V3_Large_FPN_Weights(WeightsEnum):
        transforms=CocoEval,
        meta={
            **_COMMON_META,
            "num_params": 19386354,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-mobilenetv3-large-fpn",
            "map": 32.8,
        },
@@ -68,6 +73,7 @@ class FasterRCNN_MobileNet_V3_Large_320_FPN_Weights(WeightsEnum):
        transforms=CocoEval,
        meta={
            **_COMMON_META,
            "num_params": 19386354,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-mobilenetv3-large-320-fpn",
            "map": 22.8,
        },
5 changes: 5 additions & 0 deletions torchvision/prototype/models/detection/keypoint_rcnn.py
@@ -24,6 +24,9 @@


_COMMON_META = {
    "task": "image_object_detection",

Contributor (inline review comment): image keypoint detection, I guess?

    "architecture": "KeypointRCNN",
    "publication_year": 2017,
    "categories": _COCO_PERSON_CATEGORIES,
    "keypoint_names": _COCO_PERSON_KEYPOINT_NAMES,
    "interpolation": InterpolationMode.BILINEAR,
@@ -36,6 +39,7 @@ class KeypointRCNN_ResNet50_FPN_Weights(WeightsEnum):
        transforms=CocoEval,
        meta={
            **_COMMON_META,
            "num_params": 59137258,
            "recipe": "https://github.com/pytorch/vision/issues/1606",
            "map": 50.6,
            "map_kp": 61.1,
@@ -46,6 +50,7 @@ class KeypointRCNN_ResNet50_FPN_Weights(WeightsEnum):
        transforms=CocoEval,
        meta={
            **_COMMON_META,
            "num_params": 59137258,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#keypoint-r-cnn",
            "map": 54.6,
            "map_kp": 65.0,
4 changes: 4 additions & 0 deletions torchvision/prototype/models/detection/mask_rcnn.py
@@ -28,6 +28,10 @@ class MaskRCNN_ResNet50_FPN_Weights(WeightsEnum):
        url="https://download.pytorch.org/models/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth",
        transforms=CocoEval,
        meta={
            "task": "image_object_detection",

Member (inline review comment): I'm unclear on what to put here. Indeed, this model could also be categorized as image_instance_segmentation. Maybe having task be a list would be of help here? Ensuring that this is aligned with the paperswithcode categorization would be good as well: https://paperswithcode.com/paper/mask-r-cnn

Contributor Author (reply): Yeah, same here. I'll leave it as is for now and review it in more detail in the future.

            "architecture": "MaskRCNN",
            "publication_year": 2017,
            "num_params": 44401393,
            "categories": _COCO_CATEGORIES,
            "interpolation": InterpolationMode.BILINEAR,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#mask-r-cnn",
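For illustration only, a hypothetical sketch of the list-valued task idea floated in the review thread above — not what this PR merged, and the membership-style check is an assumption about how the schema test would adapt:

# Hypothetical (not adopted in this PR): a list-valued "task" would let
# Mask R-CNN carry both of the labels discussed in the review thread.
meta = {
    "task": ["image_object_detection", "image_instance_segmentation"],
    "architecture": "MaskRCNN",
    "publication_year": 2017,
    "num_params": 44401393,
}

# The schema test would then need a membership check instead of equality:
assert "image_object_detection" in meta["task"]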
4 changes: 4 additions & 0 deletions torchvision/prototype/models/detection/retinanet.py
@@ -29,6 +29,10 @@ class RetinaNet_ResNet50_FPN_Weights(WeightsEnum):
        url="https://download.pytorch.org/models/retinanet_resnet50_fpn_coco-eeacb38b.pth",
        transforms=CocoEval,
        meta={
            "task": "image_object_detection",
            "architecture": "RetinaNet",
            "publication_year": 2017,
            "num_params": 34014999,
            "categories": _COCO_CATEGORIES,
            "interpolation": InterpolationMode.BILINEAR,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#retinanet",
4 changes: 4 additions & 0 deletions torchvision/prototype/models/detection/ssd.py
@@ -27,6 +27,10 @@ class SSD300_VGG16_Weights(WeightsEnum):
        url="https://download.pytorch.org/models/ssd300_vgg16_coco-b556d3b4.pth",
        transforms=CocoEval,
        meta={
            "task": "image_object_detection",
            "architecture": "SSD",
            "publication_year": 2015,
            "num_params": 35641826,
            "size": (300, 300),
            "categories": _COCO_CATEGORIES,
            "interpolation": InterpolationMode.BILINEAR,
4 changes: 4 additions & 0 deletions torchvision/prototype/models/detection/ssdlite.py
@@ -32,6 +32,10 @@ class SSDLite320_MobileNet_V3_Large_Weights(WeightsEnum):
        url="https://download.pytorch.org/models/ssdlite320_mobilenet_v3_large_coco-a79551df.pth",
        transforms=CocoEval,
        meta={
            "task": "image_object_detection",
            "architecture": "SSDLite",
            "publication_year": 2018,
            "num_params": 3440060,
            "size": (320, 320),
            "categories": _COCO_CATEGORIES,
            "interpolation": InterpolationMode.BILINEAR,
12 changes: 12 additions & 0 deletions torchvision/prototype/models/efficientnet.py
@@ -63,6 +63,9 @@ def _efficientnet(


_COMMON_META = {
    "task": "image_classification",
    "architecture": "EfficientNet",
    "publication_year": 2019,
    "categories": _IMAGENET_CATEGORIES,
    "interpolation": InterpolationMode.BICUBIC,
    "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#efficientnet",
@@ -75,6 +78,7 @@ class EfficientNet_B0_Weights(WeightsEnum):
        transforms=partial(ImageNetEval, crop_size=224, resize_size=256, interpolation=InterpolationMode.BICUBIC),
        meta={
            **_COMMON_META,
            "num_params": 5288548,
            "size": (224, 224),
            "acc@1": 77.692,
            "acc@5": 93.532,
@@ -89,6 +93,7 @@ class EfficientNet_B1_Weights(WeightsEnum):
        transforms=partial(ImageNetEval, crop_size=240, resize_size=256, interpolation=InterpolationMode.BICUBIC),
        meta={
            **_COMMON_META,
            "num_params": 7794184,
            "size": (240, 240),
            "acc@1": 78.642,
            "acc@5": 94.186,
@@ -99,6 +104,7 @@ class EfficientNet_B1_Weights(WeightsEnum):
        transforms=partial(ImageNetEval, crop_size=240, resize_size=255, interpolation=InterpolationMode.BILINEAR),
        meta={
            **_COMMON_META,
            "num_params": 7794184,
            "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-lr-wd-crop-tuning",
            "interpolation": InterpolationMode.BILINEAR,
            "size": (240, 240),
@@ -115,6 +121,7 @@ class EfficientNet_B2_Weights(WeightsEnum):
        transforms=partial(ImageNetEval, crop_size=288, resize_size=288, interpolation=InterpolationMode.BICUBIC),
        meta={
            **_COMMON_META,
            "num_params": 9109994,
            "size": (288, 288),
            "acc@1": 80.608,
            "acc@5": 95.310,
@@ -129,6 +136,7 @@ class EfficientNet_B3_Weights(WeightsEnum):
        transforms=partial(ImageNetEval, crop_size=300, resize_size=320, interpolation=InterpolationMode.BICUBIC),
        meta={
            **_COMMON_META,
            "num_params": 12233232,
            "size": (300, 300),
            "acc@1": 82.008,
            "acc@5": 96.054,
@@ -143,6 +151,7 @@ class EfficientNet_B4_Weights(WeightsEnum):
        transforms=partial(ImageNetEval, crop_size=380, resize_size=384, interpolation=InterpolationMode.BICUBIC),
        meta={
            **_COMMON_META,
            "num_params": 19341616,
            "size": (380, 380),
            "acc@1": 83.384,
            "acc@5": 96.594,
@@ -157,6 +166,7 @@ class EfficientNet_B5_Weights(WeightsEnum):
        transforms=partial(ImageNetEval, crop_size=456, resize_size=456, interpolation=InterpolationMode.BICUBIC),
        meta={
            **_COMMON_META,
            "num_params": 30389784,
            "size": (456, 456),
            "acc@1": 83.444,
            "acc@5": 96.628,
@@ -171,6 +181,7 @@ class EfficientNet_B6_Weights(WeightsEnum):
        transforms=partial(ImageNetEval, crop_size=528, resize_size=528, interpolation=InterpolationMode.BICUBIC),
        meta={
            **_COMMON_META,
            "num_params": 43040704,
            "size": (528, 528),
            "acc@1": 84.008,
            "acc@5": 96.916,
@@ -185,6 +196,7 @@ class EfficientNet_B7_Weights(WeightsEnum):
        transforms=partial(ImageNetEval, crop_size=600, resize_size=600, interpolation=InterpolationMode.BICUBIC),
        meta={
            **_COMMON_META,
            "num_params": 66347960,
            "size": (600, 600),
            "acc@1": 84.122,
            "acc@5": 96.908,
4 changes: 4 additions & 0 deletions torchvision/prototype/models/googlenet.py
@@ -19,6 +19,10 @@ class GoogLeNet_Weights(WeightsEnum):
        url="https://download.pytorch.org/models/googlenet-1378be20.pth",
        transforms=partial(ImageNetEval, crop_size=224),
        meta={
            "task": "image_classification",
            "architecture": "GoogLeNet",
            "publication_year": 2014,
            "num_params": 6624904,
            "size": (224, 224),
            "categories": _IMAGENET_CATEGORIES,
            "interpolation": InterpolationMode.BILINEAR,
4 changes: 4 additions & 0 deletions torchvision/prototype/models/inception.py
@@ -18,6 +18,10 @@ class Inception_V3_Weights(WeightsEnum):
        url="https://download.pytorch.org/models/inception_v3_google-0cc3c7bd.pth",
        transforms=partial(ImageNetEval, crop_size=299, resize_size=342),
        meta={
            "task": "image_classification",
            "architecture": "InceptionV3",
            "publication_year": 2015,
            "num_params": 27161264,
            "size": (299, 299),
            "categories": _IMAGENET_CATEGORIES,
            "interpolation": InterpolationMode.BILINEAR,
5 changes: 5 additions & 0 deletions torchvision/prototype/models/mnasnet.py
@@ -24,6 +24,9 @@


_COMMON_META = {
    "task": "image_classification",
    "architecture": "MNASNet",
    "publication_year": 2018,
    "size": (224, 224),
    "categories": _IMAGENET_CATEGORIES,
    "interpolation": InterpolationMode.BILINEAR,
@@ -37,6 +40,7 @@ class MNASNet0_5_Weights(WeightsEnum):
        transforms=partial(ImageNetEval, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 2218512,
            "acc@1": 67.734,
            "acc@5": 87.490,
        },
@@ -55,6 +59,7 @@ class MNASNet1_0_Weights(WeightsEnum):
        transforms=partial(ImageNetEval, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 4383312,
            "acc@1": 73.456,
            "acc@5": 91.510,
        },
4 changes: 4 additions & 0 deletions torchvision/prototype/models/mobilenetv2.py
@@ -18,6 +18,10 @@ class MobileNet_V2_Weights(WeightsEnum):
        url="https://download.pytorch.org/models/mobilenet_v2-b0353104.pth",
        transforms=partial(ImageNetEval, crop_size=224),
        meta={
            "task": "image_classification",
            "architecture": "MobileNetV2",
            "publication_year": 2018,
            "num_params": 3504872,
            "size": (224, 224),
            "categories": _IMAGENET_CATEGORIES,
            "interpolation": InterpolationMode.BILINEAR,
6 changes: 6 additions & 0 deletions torchvision/prototype/models/mobilenetv3.py
@@ -38,6 +38,9 @@ def _mobilenet_v3(


_COMMON_META = {
    "task": "image_classification",
    "architecture": "MobileNetV3",
    "publication_year": 2019,
    "size": (224, 224),
    "categories": _IMAGENET_CATEGORIES,
    "interpolation": InterpolationMode.BILINEAR,
@@ -50,6 +53,7 @@ class MobileNet_V3_Large_Weights(WeightsEnum):
        transforms=partial(ImageNetEval, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 5483032,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#mobilenetv3-large--small",
            "acc@1": 74.042,
            "acc@5": 91.340,
@@ -60,6 +64,7 @@ class MobileNet_V3_Large_Weights(WeightsEnum):
        transforms=partial(ImageNetEval, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 5483032,
            "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-reg-tuning",
            "acc@1": 75.274,
            "acc@5": 92.566,
@@ -74,6 +79,7 @@ class MobileNet_V3_Small_Weights(WeightsEnum):
        transforms=partial(ImageNetEval, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 2542856,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#mobilenetv3-large--small",
            "acc@1": 67.668,
            "acc@5": 87.402,
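To close the loop, here is how the new fields surface to callers — a minimal sketch, not part of the diff, assuming the prototype import path from the file headers and the .default accessor exercised by the test above:

# Minimal sketch (not part of this PR): reading the new meta fields off the
# default weights entry of a prototype model.
from torchvision.prototype.models import MobileNet_V3_Large_Weights

w = MobileNet_V3_Large_Weights.default
print(w.meta["task"])              # "image_classification"
print(w.meta["architecture"])      # "MobileNetV3"
print(w.meta["publication_year"])  # 2019
print(w.meta["num_params"])        # 5483032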