 # language governing permissions and limitations under the License.
 from __future__ import absolute_import

-import gzip
 import os
-import pickle
-import sys
-import pytest
-import tests.integ

+import airflow
+import pytest
 import numpy as np
+from airflow import DAG
+from airflow.contrib.operators.sagemaker_training_operator import SageMakerTrainingOperator
+from airflow.contrib.operators.sagemaker_transform_operator import SageMakerTransformOperator
+from six.moves.urllib.parse import urlparse

+import tests.integ
 from sagemaker import (
     KMeans,
     FactorizationMachines,
 from sagemaker.pytorch.estimator import PyTorch
 from sagemaker.sklearn import SKLearn
 from sagemaker.tensorflow import TensorFlow
-from sagemaker.workflow import airflow as sm_airflow
 from sagemaker.utils import sagemaker_timestamp
-
-import airflow
-from airflow import DAG
-from airflow.contrib.operators.sagemaker_training_operator import SageMakerTrainingOperator
-from airflow.contrib.operators.sagemaker_transform_operator import SageMakerTransformOperator
-
+from sagemaker.workflow import airflow as sm_airflow
 from sagemaker.xgboost import XGBoost
-from tests.integ import DATA_DIR, PYTHON_VERSION
+from tests.integ import datasets, DATA_DIR, PYTHON_VERSION
 from tests.integ.record_set import prepare_record_set_from_local_files
 from tests.integ.timeout import timeout

-from six.moves.urllib.parse import urlparse
-
 PYTORCH_MNIST_DIR = os.path.join(DATA_DIR, "pytorch_mnist")
 PYTORCH_MNIST_SCRIPT = os.path.join(PYTORCH_MNIST_DIR, "mnist.py")
 AIRFLOW_CONFIG_TIMEOUT_IN_SECONDS = 10
@@ -101,13 +95,6 @@ def test_byo_airflow_config_uploads_data_source_to_s3_when_inputs_provided(
 @pytest.mark.canary_quick
 def test_kmeans_airflow_config_uploads_data_source_to_s3(sagemaker_session, cpu_instance_type):
     with timeout(seconds=AIRFLOW_CONFIG_TIMEOUT_IN_SECONDS):
-        data_path = os.path.join(DATA_DIR, "one_p_mnist", "mnist.pkl.gz")
-        pickle_args = {} if sys.version_info.major == 2 else {"encoding": "latin1"}
-
-        # Load the data into memory as numpy arrays
-        with gzip.open(data_path, "rb") as f:
-            train_set, _, _ = pickle.load(f, **pickle_args)
-
         kmeans = KMeans(
             role=ROLE,
             train_instance_count=SINGLE_INSTANCE_COUNT,
@@ -126,7 +113,7 @@ def test_kmeans_airflow_config_uploads_data_source_to_s3(sagemaker_session, cpu_
         kmeans.center_factor = 1
         kmeans.eval_metrics = ["ssd", "msd"]

-        records = kmeans.record_set(train_set[0][:100])
+        records = kmeans.record_set(datasets.one_p_mnist()[0][:100])

         training_config = _build_airflow_workflow(
             estimator=kmeans, instance_type=cpu_instance_type, inputs=records
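The gzip/pickle boilerplate deleted above is consolidated behind the new tests.integ.datasets module, which this diff only imports. A minimal sketch of what its one_p_mnist() helper might look like, reconstructed from the removed inline loading code (the module layout is an assumption; the real implementation is outside this diff):

# Hypothetical tests/integ/datasets.py, inferred from the deleted inline code.
import gzip
import os
import pickle
import sys

from tests.integ import DATA_DIR


def one_p_mnist():
    """Return the 1%-MNIST training split as (features, labels) numpy arrays."""
    data_path = os.path.join(DATA_DIR, "one_p_mnist", "mnist.pkl.gz")
    # Python 2 pickle.load() takes no encoding kwarg; Python 3 needs latin1
    # to deserialize numpy arrays that were pickled under Python 2.
    pickle_args = {} if sys.version_info.major == 2 else {"encoding": "latin1"}

    # Load the data into memory as numpy arrays.
    with gzip.open(data_path, "rb") as f:
        train_set, _, _ = pickle.load(f, **pickle_args)

    return train_set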
@@ -140,13 +127,6 @@ def test_kmeans_airflow_config_uploads_data_source_to_s3(sagemaker_session, cpu_

 def test_fm_airflow_config_uploads_data_source_to_s3(sagemaker_session, cpu_instance_type):
     with timeout(seconds=AIRFLOW_CONFIG_TIMEOUT_IN_SECONDS):
-        data_path = os.path.join(DATA_DIR, "one_p_mnist", "mnist.pkl.gz")
-        pickle_args = {} if sys.version_info.major == 2 else {"encoding": "latin1"}
-
-        # Load the data into memory as numpy arrays
-        with gzip.open(data_path, "rb") as f:
-            train_set, _, _ = pickle.load(f, **pickle_args)
-
         fm = FactorizationMachines(
             role=ROLE,
             train_instance_count=SINGLE_INSTANCE_COUNT,
@@ -160,7 +140,8 @@ def test_fm_airflow_config_uploads_data_source_to_s3(sagemaker_session, cpu_inst
             sagemaker_session=sagemaker_session,
         )

-        records = fm.record_set(train_set[0][:200], train_set[1][:200].astype("float32"))
+        training_set = datasets.one_p_mnist()
+        records = fm.record_set(training_set[0][:200], training_set[1][:200].astype("float32"))

         training_config = _build_airflow_workflow(
             estimator=fm, instance_type=cpu_instance_type, inputs=records
@@ -206,13 +187,6 @@ def test_ipinsights_airflow_config_uploads_data_source_to_s3(sagemaker_session,

 def test_knn_airflow_config_uploads_data_source_to_s3(sagemaker_session, cpu_instance_type):
     with timeout(seconds=AIRFLOW_CONFIG_TIMEOUT_IN_SECONDS):
-        data_path = os.path.join(DATA_DIR, "one_p_mnist", "mnist.pkl.gz")
-        pickle_args = {} if sys.version_info.major == 2 else {"encoding": "latin1"}
-
-        # Load the data into memory as numpy arrays
-        with gzip.open(data_path, "rb") as f:
-            train_set, _, _ = pickle.load(f, **pickle_args)
-
         knn = KNN(
             role=ROLE,
             train_instance_count=SINGLE_INSTANCE_COUNT,
@@ -223,7 +197,8 @@ def test_knn_airflow_config_uploads_data_source_to_s3(sagemaker_session, cpu_ins
             sagemaker_session=sagemaker_session,
         )

-        records = knn.record_set(train_set[0][:200], train_set[1][:200].astype("float32"))
+        training_set = datasets.one_p_mnist()
+        records = knn.record_set(training_set[0][:200], training_set[1][:200].astype("float32"))

         training_config = _build_airflow_workflow(
             estimator=knn, instance_type=cpu_instance_type, inputs=records
@@ -277,16 +252,10 @@ def test_linearlearner_airflow_config_uploads_data_source_to_s3(
     sagemaker_session, cpu_instance_type
 ):
     with timeout(seconds=AIRFLOW_CONFIG_TIMEOUT_IN_SECONDS):
-        data_path = os.path.join(DATA_DIR, "one_p_mnist", "mnist.pkl.gz")
-        pickle_args = {} if sys.version_info.major == 2 else {"encoding": "latin1"}
-
-        # Load the data into memory as numpy arrays
-        with gzip.open(data_path, "rb") as f:
-            train_set, _, _ = pickle.load(f, **pickle_args)
-
-        train_set[1][:100] = 1
-        train_set[1][100:200] = 0
-        train_set = train_set[0], train_set[1].astype(np.dtype("float32"))
+        training_set = datasets.one_p_mnist()
+        training_set[1][:100] = 1
+        training_set[1][100:200] = 0
+        training_set = training_set[0], training_set[1].astype(np.dtype("float32"))

         ll = LinearLearner(
             ROLE,
@@ -331,7 +300,7 @@ def test_linearlearner_airflow_config_uploads_data_source_to_s3(
         ll.early_stopping_tolerance = 0.0001
         ll.early_stopping_patience = 3

-        records = ll.record_set(train_set[0][:200], train_set[1][:200])
+        records = ll.record_set(training_set[0][:200], training_set[1][:200])

         training_config = _build_airflow_workflow(
             estimator=ll, instance_type=cpu_instance_type, inputs=records
@@ -380,13 +349,6 @@ def test_ntm_airflow_config_uploads_data_source_to_s3(sagemaker_session, cpu_ins
 @pytest.mark.canary_quick
 def test_pca_airflow_config_uploads_data_source_to_s3(sagemaker_session, cpu_instance_type):
     with timeout(seconds=AIRFLOW_CONFIG_TIMEOUT_IN_SECONDS):
-        data_path = os.path.join(DATA_DIR, "one_p_mnist", "mnist.pkl.gz")
-        pickle_args = {} if sys.version_info.major == 2 else {"encoding": "latin1"}
-
-        # Load the data into memory as numpy arrays
-        with gzip.open(data_path, "rb") as f:
-            train_set, _, _ = pickle.load(f, **pickle_args)
-
         pca = PCA(
             role=ROLE,
             train_instance_count=SINGLE_INSTANCE_COUNT,
@@ -399,7 +361,7 @@ def test_pca_airflow_config_uploads_data_source_to_s3(sagemaker_session, cpu_ins
         pca.subtract_mean = True
         pca.extra_components = 5

-        records = pca.record_set(train_set[0][:100])
+        records = pca.record_set(datasets.one_p_mnist()[0][:100])

         training_config = _build_airflow_workflow(
             estimator=pca, instance_type=cpu_instance_type, inputs=records
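Every test above funnels into the shared _build_airflow_workflow helper, which this diff does not touch. Judging from the imports moved to the top of the file (sm_airflow, DAG, and the two contrib operators), its shape is plausibly something like the sketch below; the DAG name, task id, and default_args are assumptions for orientation only:

# Hypothetical sketch of the shared helper; the real implementation is outside this diff.
def _build_airflow_workflow(estimator, instance_type, inputs=None):
    # sagemaker.workflow.airflow translates the estimator and its inputs into
    # the plain dict config that the Airflow contrib operators consume.
    training_config = sm_airflow.training_config(estimator=estimator, inputs=inputs)

    default_args = {
        "owner": "airflow",
        "start_date": airflow.utils.dates.days_ago(2),
        "provide_context": True,
    }
    dag = DAG("sagemaker_example", default_args=default_args, schedule_interval="@once")

    SageMakerTrainingOperator(
        task_id="training", config=training_config, wait_for_completion=True, dag=dag
    )
    # The SageMakerTransformOperator import suggests the real helper also wires
    # a transform step into the DAG; instance_type would feed that step (omitted here).

    # The tests then assert on the returned config; the urlparse import above
    # hints at checking where the training data source landed in S3.
    return training_config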