@@ -91,7 +91,7 @@ def _get_full_cpu_image_uri_with_ei(version, py_version=PYTHON_VERSION):
9191
9292def _chainer_estimator (
9393 sagemaker_session ,
94- framework_version = defaults . CHAINER_VERSION ,
94+ framework_version ,
9595 train_instance_type = None ,
9696 base_job_name = None ,
9797 use_mpi = None ,
@@ -202,13 +202,14 @@ def _create_train_job_with_additional_hyperparameters(version):
202202 }
203203
204204
205- def test_additional_hyperparameters (sagemaker_session ):
205+ def test_additional_hyperparameters (sagemaker_session , chainer_version ):
206206 chainer = _chainer_estimator (
207207 sagemaker_session ,
208208 use_mpi = True ,
209209 num_processes = 4 ,
210210 process_slots_per_host = 10 ,
211211 additional_mpi_options = "-x MY_ENVIRONMENT_VARIABLE" ,
212+ framework_version = chainer_version ,
212213 )
213214 assert bool (strtobool (chainer .hyperparameters ()["sagemaker_use_mpi" ]))
214215 assert int (chainer .hyperparameters ()["sagemaker_num_processes" ]) == 4
@@ -300,7 +301,7 @@ def test_create_model(sagemaker_session, chainer_version):
300301 assert model .vpc_config is None
301302
302303
303- def test_create_model_with_optional_params (sagemaker_session ):
304+ def test_create_model_with_optional_params (sagemaker_session , chainer_version ):
304305 container_log_level = '"logging.INFO"'
305306 source_dir = "s3://mybucket/source"
306307 enable_cloudwatch_metrics = "true"
@@ -311,6 +312,7 @@ def test_create_model_with_optional_params(sagemaker_session):
311312 train_instance_count = INSTANCE_COUNT ,
312313 train_instance_type = INSTANCE_TYPE ,
313314 container_log_level = container_log_level ,
315+ framework_version = chainer_version ,
314316 py_version = PYTHON_VERSION ,
315317 base_job_name = "job" ,
316318 source_dir = source_dir ,
@@ -372,8 +374,8 @@ def test_chainer(strftime, sagemaker_session, chainer_version):
372374 sagemaker_session = sagemaker_session ,
373375 train_instance_count = INSTANCE_COUNT ,
374376 train_instance_type = INSTANCE_TYPE ,
375- py_version = PYTHON_VERSION ,
376377 framework_version = chainer_version ,
378+ py_version = PYTHON_VERSION ,
377379 )
378380
379381 inputs = "s3://mybucket/train"
@@ -414,62 +416,72 @@ def test_chainer(strftime, sagemaker_session, chainer_version):
414416
415417
416418@patch ("sagemaker.utils.create_tar_file" , MagicMock ())
417- def test_model (sagemaker_session ):
419+ def test_model (sagemaker_session , chainer_version ):
418420 model = ChainerModel (
419421 "s3://some/data.tar.gz" ,
420422 role = ROLE ,
421423 entry_point = SCRIPT_PATH ,
422424 sagemaker_session = sagemaker_session ,
425+ framework_version = chainer_version ,
426+ py_version = PYTHON_VERSION ,
423427 )
424428 predictor = model .deploy (1 , GPU )
425429 assert isinstance (predictor , ChainerPredictor )
426430
427431
428432@patch ("sagemaker.fw_utils.tar_and_upload_dir" , MagicMock ())
429- def test_model_prepare_container_def_accelerator_error (sagemaker_session ):
433+ def test_model_prepare_container_def_accelerator_error (sagemaker_session , chainer_version ):
430434 model = ChainerModel (
431- MODEL_DATA , role = ROLE , entry_point = SCRIPT_PATH , sagemaker_session = sagemaker_session
435+ MODEL_DATA ,
436+ role = ROLE ,
437+ entry_point = SCRIPT_PATH ,
438+ sagemaker_session = sagemaker_session ,
439+ framework_version = chainer_version ,
440+ py_version = PYTHON_VERSION ,
432441 )
433442 with pytest .raises (ValueError ):
434443 model .prepare_container_def (INSTANCE_TYPE , accelerator_type = ACCELERATOR_TYPE )
435444
436445
437- def test_train_image_default (sagemaker_session ):
446+ def test_train_image_default (sagemaker_session , chainer_version ):
438447 chainer = Chainer (
439448 entry_point = SCRIPT_PATH ,
440449 role = ROLE ,
441450 sagemaker_session = sagemaker_session ,
442451 train_instance_count = INSTANCE_COUNT ,
443452 train_instance_type = INSTANCE_TYPE ,
453+ framework_version = chainer_version ,
444454 py_version = PYTHON_VERSION ,
445455 )
446456
447- assert _get_full_cpu_image_uri (defaults . CHAINER_VERSION ) in chainer .train_image ()
457+ assert _get_full_cpu_image_uri (chainer_version ) in chainer .train_image ()
448458
449459
450460def test_train_image_cpu_instances (sagemaker_session , chainer_version ):
451461 chainer = _chainer_estimator (
452- sagemaker_session , chainer_version , train_instance_type = "ml.c2.2xlarge"
462+ sagemaker_session , framework_version = chainer_version , train_instance_type = "ml.c2.2xlarge"
453463 )
454464 assert chainer .train_image () == _get_full_cpu_image_uri (chainer_version )
455465
456466 chainer = _chainer_estimator (
457- sagemaker_session , chainer_version , train_instance_type = "ml.c4.2xlarge"
467+ sagemaker_session , framework_version = chainer_version , train_instance_type = "ml.c4.2xlarge"
458468 )
459469 assert chainer .train_image () == _get_full_cpu_image_uri (chainer_version )
460470
461- chainer = _chainer_estimator (sagemaker_session , chainer_version , train_instance_type = "ml.m16" )
471+ chainer = _chainer_estimator (
472+ sagemaker_session , framework_version = chainer_version , train_instance_type = "ml.m16"
473+ )
462474 assert chainer .train_image () == _get_full_cpu_image_uri (chainer_version )
463475
464476
465477def test_train_image_gpu_instances (sagemaker_session , chainer_version ):
466478 chainer = _chainer_estimator (
467- sagemaker_session , chainer_version , train_instance_type = "ml.g2.2xlarge"
479+ sagemaker_session , framework_version = chainer_version , train_instance_type = "ml.g2.2xlarge"
468480 )
469481 assert chainer .train_image () == _get_full_gpu_image_uri (chainer_version )
470482
471483 chainer = _chainer_estimator (
472- sagemaker_session , chainer_version , train_instance_type = "ml.p2.2xlarge"
484+ sagemaker_session , framework_version = chainer_version , train_instance_type = "ml.p2.2xlarge"
473485 )
474486 assert chainer .train_image () == _get_full_gpu_image_uri (chainer_version )
475487
@@ -597,13 +609,14 @@ def test_attach_custom_image(sagemaker_session):
597609
598610
599611@patch ("sagemaker.chainer.estimator.python_deprecation_warning" )
600- def test_estimator_py2_warning (warning , sagemaker_session ):
612+ def test_estimator_py2_warning (warning , sagemaker_session , chainer_version ):
601613 estimator = Chainer (
602614 entry_point = SCRIPT_PATH ,
603615 role = ROLE ,
604616 sagemaker_session = sagemaker_session ,
605617 train_instance_count = INSTANCE_COUNT ,
606618 train_instance_type = INSTANCE_TYPE ,
619+ framework_version = chainer_version ,
607620 py_version = "py2" ,
608621 )
609622
@@ -612,49 +625,22 @@ def test_estimator_py2_warning(warning, sagemaker_session):
612625
613626
614627@patch ("sagemaker.chainer.model.python_deprecation_warning" )
615- def test_model_py2_warning (warning , sagemaker_session ):
628+ def test_model_py2_warning (warning , sagemaker_session , chainer_version ):
616629 model = ChainerModel (
617630 MODEL_DATA ,
618631 role = ROLE ,
619632 entry_point = SCRIPT_PATH ,
620633 sagemaker_session = sagemaker_session ,
634+ framework_version = chainer_version ,
621635 py_version = "py2" ,
622636 )
623637 assert model .py_version == "py2"
624638 warning .assert_called_with (model .__framework_name__ , defaults .LATEST_PY2_VERSION )
625639
626640
627- @patch ("sagemaker.chainer.estimator.empty_framework_version_warning" )
628- def test_empty_framework_version (warning , sagemaker_session ):
629- estimator = Chainer (
630- entry_point = SCRIPT_PATH ,
631- role = ROLE ,
632- sagemaker_session = sagemaker_session ,
633- train_instance_count = INSTANCE_COUNT ,
634- train_instance_type = INSTANCE_TYPE ,
635- framework_version = None ,
636- )
637-
638- assert estimator .framework_version == defaults .CHAINER_VERSION
639- warning .assert_called_with (defaults .CHAINER_VERSION , Chainer .LATEST_VERSION )
640-
641-
642- @patch ("sagemaker.chainer.model.empty_framework_version_warning" )
643- def test_model_empty_framework_version (warning , sagemaker_session ):
644- model = ChainerModel (
645- MODEL_DATA ,
646- role = ROLE ,
647- entry_point = SCRIPT_PATH ,
648- sagemaker_session = sagemaker_session ,
649- framework_version = None ,
650- )
651- assert model .framework_version == defaults .CHAINER_VERSION
652- warning .assert_called_with (defaults .CHAINER_VERSION , defaults .LATEST_VERSION )
653-
654-
655- def test_custom_image_estimator_deploy (sagemaker_session ):
641+ def test_custom_image_estimator_deploy (sagemaker_session , chainer_version ):
656642 custom_image = "mycustomimage:latest"
657- chainer = _chainer_estimator (sagemaker_session )
643+ chainer = _chainer_estimator (sagemaker_session , framework_version = chainer_version )
658644 chainer .fit (inputs = "s3://mybucket/train" , job_name = "new_name" )
659645 model = chainer .create_model (image = custom_image )
660646 assert model .image == custom_image
0 commit comments