1717import logging
1818import os
1919import re
20- from typing import Optional
20+ from typing import Optional , Tuple
2121from packaging .version import Version
2222
2323from sagemaker import utils
5252
5353@override_pipeline_parameter_var
5454def retrieve (
55- framework ,
56- region ,
57- version = None ,
58- py_version = None ,
59- instance_type = None ,
60- accelerator_type = None ,
61- image_scope = None ,
62- container_version = None ,
63- distribution = None ,
64- base_framework_version = None ,
65- training_compiler_config = None ,
66- model_id = None ,
67- model_version = None ,
68- hub_arn = None ,
69- tolerate_vulnerable_model = False ,
70- tolerate_deprecated_model = False ,
71- sdk_version = None ,
72- inference_tool = None ,
73- serverless_inference_config = None ,
74- sagemaker_session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION ,
75- config_name = None ,
76- model_type : JumpStartModelType = JumpStartModelType .OPEN_WEIGHTS ,
55+ framework ,
56+ region ,
57+ version = None ,
58+ py_version = None ,
59+ instance_type = None ,
60+ accelerator_type = None ,
61+ image_scope = None ,
62+ container_version = None ,
63+ distribution = None ,
64+ base_framework_version = None ,
65+ training_compiler_config = None ,
66+ model_id = None ,
67+ model_version = None ,
68+ hub_arn = None ,
69+ tolerate_vulnerable_model = False ,
70+ tolerate_deprecated_model = False ,
71+ sdk_version = None ,
72+ inference_tool = None ,
73+ serverless_inference_config = None ,
74+ sagemaker_session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION ,
75+ config_name = None ,
76+ model_type : JumpStartModelType = JumpStartModelType .OPEN_WEIGHTS ,
7777) -> str :
7878 """Retrieves the ECR URI for the Docker image matching the given arguments.
7979
@@ -250,10 +250,10 @@ def retrieve(
250250 if config .get ("version_aliases" ).get (original_version ):
251251 _version = config .get ("version_aliases" )[original_version ]
252252 if (
253- config .get ("versions" , {})
254- .get (_version , {})
255- .get ("version_aliases" , {})
256- .get (base_framework_version , {})
253+ config .get ("versions" , {})
254+ .get (_version , {})
255+ .get ("version_aliases" , {})
256+ .get (base_framework_version , {})
257257 ):
258258 _base_framework_version = config .get ("versions" )[_version ]["version_aliases" ][
259259 base_framework_version
@@ -290,16 +290,16 @@ def retrieve(
290290
291291
292292def _get_image_tag (
293- container_version ,
294- distribution ,
295- final_image_scope ,
296- framework ,
297- inference_tool ,
298- instance_type ,
299- processor ,
300- py_version ,
301- tag_prefix ,
302- version ,
293+ container_version ,
294+ distribution ,
295+ final_image_scope ,
296+ framework ,
297+ inference_tool ,
298+ instance_type ,
299+ processor ,
300+ py_version ,
301+ tag_prefix ,
302+ version ,
303303):
304304 """Return image tag based on framework, container, and compute configuration(s)."""
305305 instance_type_family = utils .get_instance_type_family (instance_type )
@@ -311,8 +311,8 @@ def _get_image_tag(
311311 "instance type" ,
312312 )
313313 if (
314- instance_type_family in GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY
315- or final_image_scope == INFERENCE_GRAVITON
314+ instance_type_family in GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY
315+ or final_image_scope == INFERENCE_GRAVITON
316316 ):
317317 version_to_arm64_tag_mapping = {
318318 "xgboost" : {
@@ -330,7 +330,7 @@ def _get_image_tag(
330330 tag = _format_tag (tag_prefix , processor , py_version , container_version , inference_tool )
331331
332332 if instance_type is not None and _should_auto_select_container_version (
333- instance_type , distribution
333+ instance_type , distribution
334334 ):
335335 container_versions = {
336336 "tensorflow-2.3-gpu-py37" : "cu110-ubuntu18.04-v3" ,
@@ -398,7 +398,7 @@ def _validate_instance_deprecation(framework, instance_type, version):
398398 """Check if instance type is deprecated for a certain framework with a certain version"""
399399 if utils .get_instance_type_family (instance_type ) == "p2" :
400400 if (framework == "pytorch" and Version (version ) >= Version ("1.13" )) or (
401- framework == "tensorflow" and Version (version ) >= Version ("2.12" )
401+ framework == "tensorflow" and Version (version ) >= Version ("2.12" )
402402 ):
403403 raise ValueError (
404404 "P2 instances have been deprecated for sagemaker jobs starting PyTorch 1.13 and TensorFlow 2.12"
@@ -411,17 +411,17 @@ def _validate_for_suppported_frameworks_and_instance_type(framework, instance_ty
411411 """Validate if framework is supported for the instance_type"""
412412 # Validate for Trainium allowed frameworks
413413 if (
414- instance_type is not None
415- and "trn" in instance_type
416- and framework not in TRAINIUM_ALLOWED_FRAMEWORKS
414+ instance_type is not None
415+ and "trn" in instance_type
416+ and framework not in TRAINIUM_ALLOWED_FRAMEWORKS
417417 ):
418418 _validate_framework (framework , TRAINIUM_ALLOWED_FRAMEWORKS , "framework" , "Trainium" )
419419
420420 # Validate for Graviton allowed frameowrks
421421 if (
422- instance_type is not None
423- and utils .get_instance_type_family (instance_type ) in GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY
424- and framework not in GRAVITON_ALLOWED_FRAMEWORKS
422+ instance_type is not None
423+ and utils .get_instance_type_family (instance_type ) in GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY
424+ and framework not in GRAVITON_ALLOWED_FRAMEWORKS
425425 ):
426426 _validate_framework (framework , GRAVITON_ALLOWED_FRAMEWORKS , "framework" , "Graviton" )
427427
@@ -436,8 +436,8 @@ def config_for_framework(framework):
436436def _get_final_image_scope (framework , instance_type , image_scope ):
437437 """Return final image scope based on provided framework and instance type."""
438438 if (
439- framework in GRAVITON_ALLOWED_FRAMEWORKS
440- and utils .get_instance_type_family (instance_type ) in GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY
439+ framework in GRAVITON_ALLOWED_FRAMEWORKS
440+ and utils .get_instance_type_family (instance_type ) in GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY
441441 ):
442442 return INFERENCE_GRAVITON
443443 if image_scope is None and framework in (XGBOOST_FRAMEWORK , SKLEARN_FRAMEWORK ):
@@ -635,16 +635,16 @@ def _format_tag(tag_prefix, processor, py_version, container_version, inference_
635635
636636@override_pipeline_parameter_var
637637def get_training_image_uri (
638- region ,
639- framework ,
640- framework_version = None ,
641- py_version = None ,
642- image_uri = None ,
643- distribution = None ,
644- compiler_config = None ,
645- tensorflow_version = None ,
646- pytorch_version = None ,
647- instance_type = None ,
638+ region ,
639+ framework ,
640+ framework_version = None ,
641+ py_version = None ,
642+ image_uri = None ,
643+ distribution = None ,
644+ compiler_config = None ,
645+ tensorflow_version = None ,
646+ pytorch_version = None ,
647+ instance_type = None ,
648648) -> str :
649649 """Retrieves the image URI for training.
650650
@@ -746,3 +746,141 @@ def get_base_python_image_uri(region, py_version="310") -> str:
746746 repo_and_tag = repo + ":" + version
747747
748748 return ECR_URI_TEMPLATE .format (registry = registry , hostname = hostname , repository = repo_and_tag )
749+
750+
def get_latest_container_image(
    framework: str,
    image_scope: Optional[str] = None,
    instance_type: Optional[str] = None,
    py_version: Optional[str] = None,
    region: str = "us-west-2",
    version: Optional[str] = None,
    accelerator_type=None,
    container_version=None,
    distribution=None,
    base_framework_version=None,
    training_compiler_config=None,
    model_id=None,
    model_version=None,
    hub_arn=None,
    sdk_version=None,
    inference_tool=None,
    serverless_inference_config=None,
    config_name=None,
) -> Tuple[str, Optional[str]]:
    """Retrieves the latest container image URI.

    When ``version`` is not supplied, the latest version is looked up in the
    framework's config and then forwarded, together with every other argument,
    to :func:`retrieve`.

    Args:
        framework (str): The name of the framework or algorithm.
        image_scope (str): The image type, i.e. what it is used for.
            Valid values: "training", "inference", "inference_graviton", "eia".
            If ``accelerator_type`` is set, ``image_scope`` is ignored.
        instance_type (str): The SageMaker instance type. For supported types, see
            https://aws.amazon.com/sagemaker/pricing. This is required if
            there are different images for different processor types.
        py_version (str): The Python version. This is required if there is
            more than one supported Python version for the given framework version.
        region (str): The AWS region (default: "us-west-2").
        version (str): The framework or algorithm version. If omitted (or empty),
            the latest version found in the framework config is used.
        accelerator_type (str): Elastic Inference accelerator type. For more, see
            https://docs.aws.amazon.com/sagemaker/latest/dg/ei.html.
        container_version (str): the version of docker image.
            Ideally the value of parameter should be created inside the framework.
            For custom use, see the list of supported container versions:
            https://github.com/aws/deep-learning-containers/blob/master/available_images.md
            (default: None).
        distribution (dict): A dictionary with information on how to run distributed training
            (default: None).
        base_framework_version: Passed through unchanged to :func:`retrieve`
            (default: None).
        training_compiler_config (:class:`~sagemaker.training_compiler.TrainingCompilerConfig`):
            A configuration class for the SageMaker Training Compiler
            (default: None).
        model_id (str): The JumpStart model ID for which to retrieve the image URI
            (default: None).
        model_version (str): The version of the JumpStart model for which to retrieve the
            image URI (default: None).
        hub_arn (str): The arn of the SageMaker Hub for which to retrieve
            model details from. (Default: None).
        sdk_version (str): the version of python-sdk that will be used in the image retrieval.
            (default: None).
        inference_tool (str): the tool that will be used to aid in the inference.
            Valid values: "neuron, neuronx, None"
            (default: None).
        serverless_inference_config (sagemaker.serverless.ServerlessInferenceConfig):
            Specifies configuration related to serverless endpoint. Instance type is
            not provided in serverless inference. So this is used to determine processor type.
        config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).

    Returns:
        Tuple[str, Optional[str]]: The image URI returned by :func:`retrieve` and the
        version that was passed to it. The version is ``None`` when the caller gave
        none and no latest version could be determined from the config.

    Raises:
        ValueError: If the config for ``framework`` cannot be found or is empty.
    """
    try:
        framework_config = config_for_framework(framework)
    except FileNotFoundError as err:
        # Chain the original error so the missing-config cause stays in the traceback.
        raise ValueError(f"Invalid framework {framework}") from err

    if not framework_config:
        raise ValueError(f"Invalid framework {framework}")

    if not version:
        version = _fetch_latest_version_from_config(framework_config, image_scope)

    image_uri = retrieve(
        framework=framework,
        region=region,
        version=version,
        instance_type=instance_type,
        py_version=py_version,
        accelerator_type=accelerator_type,
        image_scope=image_scope,
        container_version=container_version,
        distribution=distribution,
        base_framework_version=base_framework_version,
        training_compiler_config=training_compiler_config,
        model_id=model_id,
        model_version=model_version,
        hub_arn=hub_arn,
        sdk_version=sdk_version,
        inference_tool=inference_tool,
        serverless_inference_config=serverless_inference_config,
        config_name=config_name,
    )
    return image_uri, version
841+
842+
843+ def _fetch_latest_version_from_config (framework_config : dict ,
844+ image_scope : Optional [str ] = None ) -> Optional [str ]:
845+ """ Helper function to fetch the latest version as a string from a framework's config
846+ Args:
847+ framework_config (dict): A framework config dict.
848+ image_scope (str): Scope of the image, eg: training, inference
849+ Returns:
850+ Version string if latest version found else None
851+ """
852+ if image_scope in framework_config :
853+ if image_scope_config := framework_config [image_scope ]:
854+ if "version_aliases" in image_scope_config :
855+ if "latest" in image_scope_config ["version_aliases" ]:
856+ return image_scope_config ["version_aliases" ]["latest" ]
857+ top_version = None
858+ bottom_version = None
859+
860+ if "versions" in framework_config :
861+ versions = list (framework_config ["versions" ].keys ())
862+ top_version = versions [0 ]
863+ bottom_version = versions [- 1 ]
864+ if top_version == "latest" or bottom_version == "latest" :
865+ return None
866+ elif (image_scope is not None and image_scope in framework_config
867+ and "versions" in framework_config [image_scope ]):
868+ versions = list (framework_config [image_scope ]["versions" ].keys ())
869+ top_version = versions [0 ]
870+ bottom_version = versions [- 1 ]
871+ elif "processing" in framework_config and "versions" in framework_config ["processing" ]:
872+ versions = list (framework_config ["processing" ]["versions" ].keys ())
873+ top_version = versions [0 ]
874+ bottom_version = versions [- 1 ]
875+
876+ if top_version and bottom_version :
877+ if top_version .endswith (".x" ) or bottom_version .endswith (".x" ):
878+ top_number = int (top_version [:- 2 ])
879+ bottom_number = int (bottom_version [:- 2 ])
880+ max_version = max (top_number , bottom_number )
881+ return f"{ max_version } .x"
882+ if Version (top_version ) >= Version (bottom_version ):
883+ return top_version
884+ return bottom_version
885+
886+ return None
0 commit comments