5252
5353@override_pipeline_parameter_var
5454def retrieve (
55- framework ,
56- region ,
57- version = None ,
58- py_version = None ,
59- instance_type = None ,
60- accelerator_type = None ,
61- image_scope = None ,
62- container_version = None ,
63- distribution = None ,
64- base_framework_version = None ,
65- training_compiler_config = None ,
66- model_id = None ,
67- model_version = None ,
68- hub_arn = None ,
69- tolerate_vulnerable_model = False ,
70- tolerate_deprecated_model = False ,
71- sdk_version = None ,
72- inference_tool = None ,
73- serverless_inference_config = None ,
74- sagemaker_session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION ,
75- config_name = None ,
76- model_type : JumpStartModelType = JumpStartModelType .OPEN_WEIGHTS ,
55+ framework ,
56+ region ,
57+ version = None ,
58+ py_version = None ,
59+ instance_type = None ,
60+ accelerator_type = None ,
61+ image_scope = None ,
62+ container_version = None ,
63+ distribution = None ,
64+ base_framework_version = None ,
65+ training_compiler_config = None ,
66+ model_id = None ,
67+ model_version = None ,
68+ hub_arn = None ,
69+ tolerate_vulnerable_model = False ,
70+ tolerate_deprecated_model = False ,
71+ sdk_version = None ,
72+ inference_tool = None ,
73+ serverless_inference_config = None ,
74+ sagemaker_session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION ,
75+ config_name = None ,
76+ model_type : JumpStartModelType = JumpStartModelType .OPEN_WEIGHTS ,
7777) -> str :
7878 """Retrieves the ECR URI for the Docker image matching the given arguments.
7979
@@ -250,10 +250,10 @@ def retrieve(
250250 if config .get ("version_aliases" ).get (original_version ):
251251 _version = config .get ("version_aliases" )[original_version ]
252252 if (
253- config .get ("versions" , {})
254- .get (_version , {})
255- .get ("version_aliases" , {})
256- .get (base_framework_version , {})
253+ config .get ("versions" , {})
254+ .get (_version , {})
255+ .get ("version_aliases" , {})
256+ .get (base_framework_version , {})
257257 ):
258258 _base_framework_version = config .get ("versions" )[_version ]["version_aliases" ][
259259 base_framework_version
@@ -290,16 +290,16 @@ def retrieve(
290290
291291
292292def _get_image_tag (
293- container_version ,
294- distribution ,
295- final_image_scope ,
296- framework ,
297- inference_tool ,
298- instance_type ,
299- processor ,
300- py_version ,
301- tag_prefix ,
302- version ,
293+ container_version ,
294+ distribution ,
295+ final_image_scope ,
296+ framework ,
297+ inference_tool ,
298+ instance_type ,
299+ processor ,
300+ py_version ,
301+ tag_prefix ,
302+ version ,
303303):
304304 """Return image tag based on framework, container, and compute configuration(s)."""
305305 instance_type_family = utils .get_instance_type_family (instance_type )
@@ -311,8 +311,8 @@ def _get_image_tag(
311311 "instance type" ,
312312 )
313313 if (
314- instance_type_family in GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY
315- or final_image_scope == INFERENCE_GRAVITON
314+ instance_type_family in GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY
315+ or final_image_scope == INFERENCE_GRAVITON
316316 ):
317317 version_to_arm64_tag_mapping = {
318318 "xgboost" : {
@@ -330,7 +330,7 @@ def _get_image_tag(
330330 tag = _format_tag (tag_prefix , processor , py_version , container_version , inference_tool )
331331
332332 if instance_type is not None and _should_auto_select_container_version (
333- instance_type , distribution
333+ instance_type , distribution
334334 ):
335335 container_versions = {
336336 "tensorflow-2.3-gpu-py37" : "cu110-ubuntu18.04-v3" ,
@@ -398,7 +398,7 @@ def _validate_instance_deprecation(framework, instance_type, version):
398398 """Check if instance type is deprecated for a certain framework with a certain version"""
399399 if utils .get_instance_type_family (instance_type ) == "p2" :
400400 if (framework == "pytorch" and Version (version ) >= Version ("1.13" )) or (
401- framework == "tensorflow" and Version (version ) >= Version ("2.12" )
401+ framework == "tensorflow" and Version (version ) >= Version ("2.12" )
402402 ):
403403 raise ValueError (
404404 "P2 instances have been deprecated for sagemaker jobs starting PyTorch 1.13 and TensorFlow 2.12"
@@ -411,17 +411,17 @@ def _validate_for_suppported_frameworks_and_instance_type(framework, instance_ty
411411 """Validate if framework is supported for the instance_type"""
412412 # Validate for Trainium allowed frameworks
413413 if (
414- instance_type is not None
415- and "trn" in instance_type
416- and framework not in TRAINIUM_ALLOWED_FRAMEWORKS
414+ instance_type is not None
415+ and "trn" in instance_type
416+ and framework not in TRAINIUM_ALLOWED_FRAMEWORKS
417417 ):
418418 _validate_framework (framework , TRAINIUM_ALLOWED_FRAMEWORKS , "framework" , "Trainium" )
419419
420420 # Validate for Graviton allowed frameowrks
421421 if (
422- instance_type is not None
423- and utils .get_instance_type_family (instance_type ) in GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY
424- and framework not in GRAVITON_ALLOWED_FRAMEWORKS
422+ instance_type is not None
423+ and utils .get_instance_type_family (instance_type ) in GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY
424+ and framework not in GRAVITON_ALLOWED_FRAMEWORKS
425425 ):
426426 _validate_framework (framework , GRAVITON_ALLOWED_FRAMEWORKS , "framework" , "Graviton" )
427427
@@ -436,8 +436,8 @@ def config_for_framework(framework):
436436def _get_final_image_scope (framework , instance_type , image_scope ):
437437 """Return final image scope based on provided framework and instance type."""
438438 if (
439- framework in GRAVITON_ALLOWED_FRAMEWORKS
440- and utils .get_instance_type_family (instance_type ) in GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY
439+ framework in GRAVITON_ALLOWED_FRAMEWORKS
440+ and utils .get_instance_type_family (instance_type ) in GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY
441441 ):
442442 return INFERENCE_GRAVITON
443443 if image_scope is None and framework in (XGBOOST_FRAMEWORK , SKLEARN_FRAMEWORK ):
@@ -635,16 +635,16 @@ def _format_tag(tag_prefix, processor, py_version, container_version, inference_
635635
636636@override_pipeline_parameter_var
637637def get_training_image_uri (
638- region ,
639- framework ,
640- framework_version = None ,
641- py_version = None ,
642- image_uri = None ,
643- distribution = None ,
644- compiler_config = None ,
645- tensorflow_version = None ,
646- pytorch_version = None ,
647- instance_type = None ,
638+ region ,
639+ framework ,
640+ framework_version = None ,
641+ py_version = None ,
642+ image_uri = None ,
643+ distribution = None ,
644+ compiler_config = None ,
645+ tensorflow_version = None ,
646+ pytorch_version = None ,
647+ instance_type = None ,
648648) -> str :
649649 """Retrieves the image URI for training.
650650
@@ -748,26 +748,28 @@ def get_base_python_image_uri(region, py_version="310") -> str:
748748 return ECR_URI_TEMPLATE .format (registry = registry , hostname = hostname , repository = repo_and_tag )
749749
750750
751- def get_latest_container_image (framework : str ,
752- image_scope : Optional [str ] = None ,
753- instance_type : Optional [str ] = None ,
754- py_version : Optional [str ] = None ,
755- region : str = "us-west-2" ,
756- version : Optional [str ] = None ,
757- accelerator_type = None ,
758- container_version = None ,
759- distribution = None ,
760- base_framework_version = None ,
761- training_compiler_config = None ,
762- model_id = None ,
763- model_version = None ,
764- hub_arn = None ,
765- sdk_version = None ,
766- inference_tool = None ,
767- serverless_inference_config = None ,
768- config_name = None ,
769- ) -> Tuple [str , str ]:
751+ def get_latest_container_image (
752+ framework : str ,
753+ image_scope : Optional [str ] = None ,
754+ instance_type : Optional [str ] = None ,
755+ py_version : Optional [str ] = None ,
756+ region : str = "us-west-2" ,
757+ version : Optional [str ] = None ,
758+ accelerator_type = None ,
759+ container_version = None ,
760+ distribution = None ,
761+ base_framework_version = None ,
762+ training_compiler_config = None ,
763+ model_id = None ,
764+ model_version = None ,
765+ hub_arn = None ,
766+ sdk_version = None ,
767+ inference_tool = None ,
768+ serverless_inference_config = None ,
769+ config_name = None ,
770+ ) -> Tuple [str , str ]:
770771 """Retrieves the latest container image URI
772+
771773 Args:
772774 framework (str): The name of the framework or algorithm.
773775 image_scope (str): The image type, i.e. what it is used for.
@@ -818,31 +820,34 @@ def get_latest_container_image(framework: str,
818820
819821 if not version :
820822 version = _fetch_latest_version_from_config (framework_config , image_scope )
821- image_uri = retrieve (framework = framework ,
822- region = region ,
823- version = version ,
824- instance_type = instance_type ,
825- py_version = py_version ,
826- accelerator_type = accelerator_type ,
827- image_scope = image_scope ,
828- container_version = container_version ,
829- distribution = distribution ,
830- base_framework_version = base_framework_version ,
831- training_compiler_config = training_compiler_config ,
832- model_id = model_id ,
833- model_version = model_version ,
834- hub_arn = hub_arn ,
835- sdk_version = sdk_version ,
836- inference_tool = inference_tool ,
837- serverless_inference_config = serverless_inference_config ,
838- config_name = config_name
839- )
823+ image_uri = retrieve (
824+ framework = framework ,
825+ region = region ,
826+ version = version ,
827+ instance_type = instance_type ,
828+ py_version = py_version ,
829+ accelerator_type = accelerator_type ,
830+ image_scope = image_scope ,
831+ container_version = container_version ,
832+ distribution = distribution ,
833+ base_framework_version = base_framework_version ,
834+ training_compiler_config = training_compiler_config ,
835+ model_id = model_id ,
836+ model_version = model_version ,
837+ hub_arn = hub_arn ,
838+ sdk_version = sdk_version ,
839+ inference_tool = inference_tool ,
840+ serverless_inference_config = serverless_inference_config ,
841+ config_name = config_name ,
842+ )
840843 return image_uri , version
841844
842845
843- def _fetch_latest_version_from_config (framework_config : dict ,
844- image_scope : Optional [str ] = None ) -> Optional [str ]:
845- """ Helper function to fetch the latest version as a string from a framework's config
846+ def _fetch_latest_version_from_config (
847+ framework_config : dict , image_scope : Optional [str ] = None
848+ ) -> Optional [str ]:
849+ """Helper function to fetch the latest version as a string from a framework's config
850+
846851 Args:
847852 framework_config (dict): A framework config dict.
848853 image_scope (str): Scope of the image, eg: training, inference
@@ -863,8 +868,11 @@ def _fetch_latest_version_from_config(framework_config: dict,
863868 bottom_version = versions [- 1 ]
864869 if top_version == "latest" or bottom_version == "latest" :
865870 return None
866- elif (image_scope is not None and image_scope in framework_config
867- and "versions" in framework_config [image_scope ]):
871+ elif (
872+ image_scope is not None
873+ and image_scope in framework_config
874+ and "versions" in framework_config [image_scope ]
875+ ):
868876 versions = list (framework_config [image_scope ]["versions" ].keys ())
869877 top_version = versions [0 ]
870878 bottom_version = versions [- 1 ]
0 commit comments