1515import numpy
1616import os
1717import pytest
18- from sagemaker . pytorch . defaults import LATEST_PY2_VERSION
18+
1919from sagemaker .pytorch .estimator import PyTorch
2020from sagemaker .pytorch .model import PyTorchModel
2121from sagemaker .utils import sagemaker_timestamp
22-
2322from tests .integ import (
2423 test_region ,
2524 DATA_DIR ,
26- PYTHON_VERSION ,
2725 TRAINING_DEFAULT_TIMEOUT_MINUTES ,
2826 EI_SUPPORTED_REGIONS ,
2927)
3937
4038
4139@pytest .fixture (scope = "module" , name = "pytorch_training_job" )
42- def fixture_training_job (sagemaker_session , pytorch_full_version , cpu_instance_type ):
40+ def fixture_training_job (
41+ sagemaker_session , pytorch_full_version , pytorch_full_py_version , cpu_instance_type
42+ ):
4343 with timeout (minutes = TRAINING_DEFAULT_TIMEOUT_MINUTES ):
44- pytorch = _get_pytorch_estimator (sagemaker_session , pytorch_full_version , cpu_instance_type )
44+ pytorch = _get_pytorch_estimator (
45+ sagemaker_session , pytorch_full_version , pytorch_full_py_version , cpu_instance_type
46+ )
4547
4648 pytorch .fit ({"training" : _upload_training_data (pytorch )})
4749 return pytorch .latest_training_job .name
4850
4951
5052@pytest .mark .canary_quick
5153@pytest .mark .regional_testing
52- @pytest .mark .skipif (
53- PYTHON_VERSION == "py2" ,
54- reason = "Python 2 is supported by PyTorch {} and lower versions." .format (LATEST_PY2_VERSION ),
55- )
56- def test_sync_fit_deploy (pytorch_training_job , sagemaker_session , cpu_instance_type ):
57- # TODO: add tests against local mode when it's ready to be used
54+ def test_fit_deploy (pytorch_training_job , sagemaker_session , cpu_instance_type ):
5855 endpoint_name = "test-pytorch-sync-fit-attach-deploy{}" .format (sagemaker_timestamp ())
5956 with timeout_and_delete_endpoint_by_name (endpoint_name , sagemaker_session ):
6057 estimator = PyTorch .attach (pytorch_training_job , sagemaker_session = sagemaker_session )
@@ -70,16 +67,12 @@ def test_sync_fit_deploy(pytorch_training_job, sagemaker_session, cpu_instance_t
7067
7168
7269@pytest .mark .local_mode
73- @pytest .mark .skipif (
74- PYTHON_VERSION == "py2" ,
75- reason = "Python 2 is supported by PyTorch {} and lower versions." .format (LATEST_PY2_VERSION ),
76- )
77- def test_fit_deploy (sagemaker_local_session , pytorch_full_version ):
70+ def test_local_fit_deploy (sagemaker_local_session , pytorch_full_version , pytorch_full_py_version ):
7871 pytorch = PyTorch (
7972 entry_point = MNIST_SCRIPT ,
8073 role = "SageMakerRole" ,
8174 framework_version = pytorch_full_version ,
82- py_version = "py3" ,
75+ py_version = pytorch_full_py_version ,
8376 train_instance_count = 1 ,
8477 train_instance_type = "local" ,
8578 sagemaker_session = sagemaker_local_session ,
@@ -99,7 +92,11 @@ def test_fit_deploy(sagemaker_local_session, pytorch_full_version):
9992
10093
10194def test_deploy_model (
102- pytorch_training_job , sagemaker_session , cpu_instance_type , pytorch_full_version
95+ pytorch_training_job ,
96+ sagemaker_session ,
97+ cpu_instance_type ,
98+ pytorch_full_version ,
99+ pytorch_full_py_version ,
103100):
104101 endpoint_name = "test-pytorch-deploy-model-{}" .format (sagemaker_timestamp ())
105102
@@ -113,7 +110,7 @@ def test_deploy_model(
113110 "SageMakerRole" ,
114111 entry_point = MNIST_SCRIPT ,
115112 framework_version = pytorch_full_version ,
116- py_version = "py3" ,
113+ py_version = pytorch_full_py_version ,
117114 sagemaker_session = sagemaker_session ,
118115 )
119116 predictor = model .deploy (1 , cpu_instance_type , endpoint_name = endpoint_name )
@@ -125,7 +122,9 @@ def test_deploy_model(
125122 assert output .shape == (batch_size , 10 )
126123
127124
128- def test_deploy_packed_model_with_entry_point_name (sagemaker_session , cpu_instance_type ):
125+ def test_deploy_packed_model_with_entry_point_name (
126+ sagemaker_session , cpu_instance_type , pytorch_full_version , pytorch_full_py_version
127+ ):
129128 endpoint_name = "test-pytorch-deploy-model-{}" .format (sagemaker_timestamp ())
130129
131130 with timeout_and_delete_endpoint_by_name (endpoint_name , sagemaker_session ):
@@ -134,8 +133,8 @@ def test_deploy_packed_model_with_entry_point_name(sagemaker_session, cpu_instan
134133 model_data ,
135134 "SageMakerRole" ,
136135 entry_point = "mnist.py" ,
137- framework_version = "1.4.0" ,
138- py_version = "py3" ,
136+ framework_version = pytorch_full_version ,
137+ py_version = pytorch_full_py_version ,
139138 sagemaker_session = sagemaker_session ,
140139 )
141140 predictor = model .deploy (1 , cpu_instance_type , endpoint_name = endpoint_name )
@@ -147,19 +146,20 @@ def test_deploy_packed_model_with_entry_point_name(sagemaker_session, cpu_instan
147146 assert output .shape == (batch_size , 10 )
148147
149148
150- @pytest .mark .skipif (PYTHON_VERSION == "py2" , reason = "PyTorch EIA does not support Python 2." )
151149@pytest .mark .skipif (
152150 test_region () not in EI_SUPPORTED_REGIONS , reason = "EI isn't supported in that specific region."
153151)
154- def test_deploy_model_with_accelerator (sagemaker_session , cpu_instance_type ):
152+ def test_deploy_model_with_accelerator (
153+ sagemaker_session , cpu_instance_type , pytorch_full_ei_version , pytorch_full_py_version
154+ ):
155155 endpoint_name = "test-pytorch-deploy-eia-{}" .format (sagemaker_timestamp ())
156156 model_data = sagemaker_session .upload_data (path = EIA_MODEL )
157157 pytorch = PyTorchModel (
158158 model_data ,
159159 "SageMakerRole" ,
160160 entry_point = EIA_SCRIPT ,
161- framework_version = "1.3.1" ,
162- py_version = "py3" ,
161+ framework_version = pytorch_full_ei_version ,
162+ py_version = pytorch_full_py_version ,
163163 sagemaker_session = sagemaker_session ,
164164 )
165165 with timeout_and_delete_endpoint_by_name (endpoint_name , sagemaker_session ):
@@ -185,13 +185,13 @@ def _upload_training_data(pytorch):
185185
186186
187187def _get_pytorch_estimator (
188- sagemaker_session , pytorch_full_version , instance_type , entry_point = MNIST_SCRIPT
188+ sagemaker_session , pytorch_version , py_version , instance_type , entry_point = MNIST_SCRIPT
189189):
190190 return PyTorch (
191191 entry_point = entry_point ,
192192 role = "SageMakerRole" ,
193- framework_version = pytorch_full_version ,
194- py_version = "py3" ,
193+ framework_version = pytorch_version ,
194+ py_version = py_version ,
195195 train_instance_count = 1 ,
196196 train_instance_type = instance_type ,
197197 sagemaker_session = sagemaker_session ,
0 commit comments