33
44import pytest
55import test_models as TM
6- import torchvision
6+ from torchvision import models
77from torchvision .models ._api import WeightsEnum , Weights
88from torchvision .models ._utils import handle_legacy_interface
99
1010run_if_test_with_prototype = pytest .mark .skipif (
11- os .getenv ("PYTORCH_TEST_WITH_PROTOTYPE " ) != "1" ,
12- reason = "Prototype tests are disabled by default. Set PYTORCH_TEST_WITH_PROTOTYPE =1 to run them." ,
11+ os .getenv ("PYTORCH_TEST_WITH_EXTENDED " ) != "1" ,
12+ reason = "Extended tests are disabled by default. Set PYTORCH_TEST_WITH_EXTENDED =1 to run them." ,
1313)
1414
1515
16- def _get_original_model (model_fn ):
17- original_module_name = model_fn .__module__ .replace (".prototype" , "" )
18- module = importlib .import_module (original_module_name )
19- return module .__dict__ [model_fn .__name__ ]
20-
21-
2216def _get_parent_module (model_fn ):
2317 parent_module_name = "." .join (model_fn .__module__ .split ("." )[:- 1 ])
2418 module = importlib .import_module (parent_module_name )
@@ -38,44 +32,33 @@ def _get_model_weights(model_fn):
3832 return None
3933
4034
41- def _build_model (fn , ** kwargs ):
42- try :
43- model = fn (** kwargs )
44- except ValueError as e :
45- msg = str (e )
46- if "No checkpoint is available" in msg :
47- pytest .skip (msg )
48- raise e
49- return model .eval ()
50-
51-
5235@pytest .mark .parametrize (
5336 "name, weight" ,
5437 [
55- ("ResNet50_Weights.IMAGENET1K_V1" , torchvision . models .ResNet50_Weights .IMAGENET1K_V1 ),
56- ("ResNet50_Weights.DEFAULT" , torchvision . models .ResNet50_Weights .IMAGENET1K_V2 ),
38+ ("ResNet50_Weights.IMAGENET1K_V1" , models .ResNet50_Weights .IMAGENET1K_V1 ),
39+ ("ResNet50_Weights.DEFAULT" , models .ResNet50_Weights .IMAGENET1K_V2 ),
5740 (
5841 "ResNet50_QuantizedWeights.DEFAULT" ,
59- torchvision . models .quantization .ResNet50_QuantizedWeights .IMAGENET1K_FBGEMM_V2 ,
42+ models .quantization .ResNet50_QuantizedWeights .IMAGENET1K_FBGEMM_V2 ,
6043 ),
6144 (
6245 "ResNet50_QuantizedWeights.IMAGENET1K_FBGEMM_V1" ,
63- torchvision . models .quantization .ResNet50_QuantizedWeights .IMAGENET1K_FBGEMM_V1 ,
46+ models .quantization .ResNet50_QuantizedWeights .IMAGENET1K_FBGEMM_V1 ,
6447 ),
6548 ],
6649)
6750def test_get_weight (name , weight ):
68- assert torchvision . models .get_weight (name ) == weight
51+ assert models .get_weight (name ) == weight
6952
7053
7154@pytest .mark .parametrize (
7255 "model_fn" ,
73- TM .get_models_from_module (torchvision . models )
74- + TM .get_models_from_module (torchvision . models .detection )
75- + TM .get_models_from_module (torchvision . models .quantization )
76- + TM .get_models_from_module (torchvision . models .segmentation )
77- + TM .get_models_from_module (torchvision . models .video )
78- + TM .get_models_from_module (torchvision . models .optical_flow ),
56+ TM .get_models_from_module (models )
57+ + TM .get_models_from_module (models .detection )
58+ + TM .get_models_from_module (models .quantization )
59+ + TM .get_models_from_module (models .segmentation )
60+ + TM .get_models_from_module (models .video )
61+ + TM .get_models_from_module (models .optical_flow ),
7962)
8063def test_naming_conventions (model_fn ):
8164 weights_enum = _get_model_weights (model_fn )
@@ -86,12 +69,12 @@ def test_naming_conventions(model_fn):
8669
8770@pytest .mark .parametrize (
8871 "model_fn" ,
89- TM .get_models_from_module (torchvision . models )
90- + TM .get_models_from_module (torchvision . models .detection )
91- + TM .get_models_from_module (torchvision . models .quantization )
92- + TM .get_models_from_module (torchvision . models .segmentation )
93- + TM .get_models_from_module (torchvision . models .video )
94- + TM .get_models_from_module (torchvision . models .optical_flow ),
72+ TM .get_models_from_module (models )
73+ + TM .get_models_from_module (models .detection )
74+ + TM .get_models_from_module (models .quantization )
75+ + TM .get_models_from_module (models .segmentation )
76+ + TM .get_models_from_module (models .video )
77+ + TM .get_models_from_module (models .optical_flow ),
9578)
9679@run_if_test_with_prototype
9780def test_schema_meta_validation (model_fn ):
0 commit comments