Skip to content

Commit 7d2b12e

Browse files
committed
Move prototype to extended tests
1 parent e67b676 commit 7d2b12e

File tree

3 files changed

+56
-51
lines changed

3 files changed

+56
-51
lines changed

.circleci/config.yml

Lines changed: 18 additions & 7 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

.circleci/config.yml.in

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -335,6 +335,20 @@ jobs:
335335
file_or_dir: test/test_onnx.py
336336

337337
unittest_prototype:
338+
docker:
339+
- image: circleci/python:3.7
340+
resource_class: xlarge
341+
steps:
342+
- checkout
343+
- install_torchvision
344+
- install_prototype_dependencies
345+
- pip_install:
346+
args: scipy pycocotools h5py
347+
descr: Install optional dependencies
348+
- run_tests_selective:
349+
file_or_dir: test/test_prototype_*.py
350+
351+
unittest_extended:
338352
docker:
339353
- image: circleci/python:3.7
340354
resource_class: xlarge
@@ -349,15 +363,11 @@ jobs:
349363
python scripts/collect_model_urls.py torchvision/prototype/models \
350364
| parallel -j0 'wget --no-verbose -O ~/.cache/torch/hub/checkpoints/`basename {}` {}\?source=ci'
351365
- install_torchvision
352-
- install_prototype_dependencies
353-
- pip_install:
354-
args: scipy pycocotools h5py
355-
descr: Install optional dependencies
356366
- run:
357-
name: Enable prototype tests
358-
command: echo 'export PYTORCH_TEST_WITH_PROTOTYPE=1' >> $BASH_ENV
367+
name: Enable extended tests
368+
command: echo 'export PYTORCH_TEST_WITH_EXTENDED=1' >> $BASH_ENV
359369
- run_tests_selective:
360-
file_or_dir: test/test_prototype_*.py
370+
file_or_dir: test/test_extended_*.py
361371

362372
binary_linux_wheel:
363373
<<: *binary_common
@@ -1094,6 +1104,7 @@ workflows:
10941104
- unittest_torchhub
10951105
- unittest_onnx
10961106
- unittest_prototype
1107+
- unittest_extended
10971108
{{ unittest_workflows() }}
10981109

10991110
cmake:

test/test_prototype_models.py renamed to test/test_extended_models.py

Lines changed: 20 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -3,22 +3,16 @@
33

44
import pytest
55
import test_models as TM
6-
import torchvision
6+
from torchvision import models
77
from torchvision.models._api import WeightsEnum, Weights
88
from torchvision.models._utils import handle_legacy_interface
99

1010
run_if_test_with_prototype = pytest.mark.skipif(
11-
os.getenv("PYTORCH_TEST_WITH_PROTOTYPE") != "1",
12-
reason="Prototype tests are disabled by default. Set PYTORCH_TEST_WITH_PROTOTYPE=1 to run them.",
11+
os.getenv("PYTORCH_TEST_WITH_EXTENDED") != "1",
12+
reason="Extended tests are disabled by default. Set PYTORCH_TEST_WITH_EXTENDED=1 to run them.",
1313
)
1414

1515

16-
def _get_original_model(model_fn):
17-
original_module_name = model_fn.__module__.replace(".prototype", "")
18-
module = importlib.import_module(original_module_name)
19-
return module.__dict__[model_fn.__name__]
20-
21-
2216
def _get_parent_module(model_fn):
2317
parent_module_name = ".".join(model_fn.__module__.split(".")[:-1])
2418
module = importlib.import_module(parent_module_name)
@@ -38,44 +32,33 @@ def _get_model_weights(model_fn):
3832
return None
3933

4034

41-
def _build_model(fn, **kwargs):
42-
try:
43-
model = fn(**kwargs)
44-
except ValueError as e:
45-
msg = str(e)
46-
if "No checkpoint is available" in msg:
47-
pytest.skip(msg)
48-
raise e
49-
return model.eval()
50-
51-
5235
@pytest.mark.parametrize(
5336
"name, weight",
5437
[
55-
("ResNet50_Weights.IMAGENET1K_V1", torchvision.models.ResNet50_Weights.IMAGENET1K_V1),
56-
("ResNet50_Weights.DEFAULT", torchvision.models.ResNet50_Weights.IMAGENET1K_V2),
38+
("ResNet50_Weights.IMAGENET1K_V1", models.ResNet50_Weights.IMAGENET1K_V1),
39+
("ResNet50_Weights.DEFAULT", models.ResNet50_Weights.IMAGENET1K_V2),
5740
(
5841
"ResNet50_QuantizedWeights.DEFAULT",
59-
torchvision.models.quantization.ResNet50_QuantizedWeights.IMAGENET1K_FBGEMM_V2,
42+
models.quantization.ResNet50_QuantizedWeights.IMAGENET1K_FBGEMM_V2,
6043
),
6144
(
6245
"ResNet50_QuantizedWeights.IMAGENET1K_FBGEMM_V1",
63-
torchvision.models.quantization.ResNet50_QuantizedWeights.IMAGENET1K_FBGEMM_V1,
46+
models.quantization.ResNet50_QuantizedWeights.IMAGENET1K_FBGEMM_V1,
6447
),
6548
],
6649
)
6750
def test_get_weight(name, weight):
68-
assert torchvision.models.get_weight(name) == weight
51+
assert models.get_weight(name) == weight
6952

7053

7154
@pytest.mark.parametrize(
7255
"model_fn",
73-
TM.get_models_from_module(torchvision.models)
74-
+ TM.get_models_from_module(torchvision.models.detection)
75-
+ TM.get_models_from_module(torchvision.models.quantization)
76-
+ TM.get_models_from_module(torchvision.models.segmentation)
77-
+ TM.get_models_from_module(torchvision.models.video)
78-
+ TM.get_models_from_module(torchvision.models.optical_flow),
56+
TM.get_models_from_module(models)
57+
+ TM.get_models_from_module(models.detection)
58+
+ TM.get_models_from_module(models.quantization)
59+
+ TM.get_models_from_module(models.segmentation)
60+
+ TM.get_models_from_module(models.video)
61+
+ TM.get_models_from_module(models.optical_flow),
7962
)
8063
def test_naming_conventions(model_fn):
8164
weights_enum = _get_model_weights(model_fn)
@@ -86,12 +69,12 @@ def test_naming_conventions(model_fn):
8669

8770
@pytest.mark.parametrize(
8871
"model_fn",
89-
TM.get_models_from_module(torchvision.models)
90-
+ TM.get_models_from_module(torchvision.models.detection)
91-
+ TM.get_models_from_module(torchvision.models.quantization)
92-
+ TM.get_models_from_module(torchvision.models.segmentation)
93-
+ TM.get_models_from_module(torchvision.models.video)
94-
+ TM.get_models_from_module(torchvision.models.optical_flow),
72+
TM.get_models_from_module(models)
73+
+ TM.get_models_from_module(models.detection)
74+
+ TM.get_models_from_module(models.quantization)
75+
+ TM.get_models_from_module(models.segmentation)
76+
+ TM.get_models_from_module(models.video)
77+
+ TM.get_models_from_module(models.optical_flow),
9578
)
9679
@run_if_test_with_prototype
9780
def test_schema_meta_validation(model_fn):

0 commit comments

Comments
 (0)