1313"""ImageSpec class module."""
1414from __future__ import absolute_import
1515
16+ import re
17+ from enum import Enum
1618from typing import Optional
1719
18- from sagemaker import image_uris , Session
19- from sagemaker .serverless import ServerlessInferenceConfig
20- from sagemaker .training_compiler .config import TrainingCompilerConfig
20+ from sagemaker import utils
21+ from sagemaker .image_uris import _validate_version_and_set_if_needed , _version_for_config , \
22+ _config_for_framework_and_scope , _validate_py_version_and_set_if_needed , _registry_from_region , ECR_URI_TEMPLATE , \
23+ _get_latest_versions , _validate_instance_deprecation , _get_image_tag , _validate_arg
24+ from packaging .version import Version
25+
26+ DEFAULT_TOLERATE_MODEL = False
27+
28+
class Framework(Enum):
    """Frameworks and algorithms an ``ImageSpec`` can resolve an image for.

    Each member's ``value`` is the framework name string that ``ImageSpec``
    hands to the ``sagemaker.image_uris`` helpers (for example
    ``_config_for_framework_and_scope``), so it must match the name the SDK
    uses for that framework's image URI configuration.
    """

    HUGGING_FACE = "huggingface"
    HUGGING_FACE_NEURON = "huggingface-neuron"
    HUGGING_FACE_NEURON_X = "huggingface-neuronx"
    HUGGING_FACE_LLM = "huggingface-llm"
    HUGGING_FACE_TEI_GPU = "huggingface-tei"
    HUGGING_FACE_TEI_CPU = "huggingface-tei-cpu"
    HUGGING_FACE_LLM_NEURONX = "huggingface-llm-neuronx"
    HUGGING_FACE_TRAINING_COMPILER = "huggingface-training-compiler"
    XGBOOST = "xgboost"
    # The SDK's configuration for this framework is named "xgboost-neo";
    # the previous value "xg-boost-neo" does not correspond to any config.
    XG_BOOST_NEO = "xgboost-neo"
    SKLEARN = "sklearn"
    PYTORCH = "pytorch"
    PYTORCH_TRAINING_COMPILER = "pytorch-training-compiler"
    DATA_WRANGLER = "data-wrangler"
    STABILITYAI = "stabilityai"
    SAGEMAKER_TRITONSERVER = "sagemaker-tritonserver"
46+
47+
class ImageScope(Enum):
    """What the image is used for (the ``image_scope`` of an ``ImageSpec``)."""

    # ``ImageSpec.retrieve`` passes the member's ``value`` to the
    # ``sagemaker.image_uris`` helpers as the scope string.
    TRAINING = "training"
    INFERENCE = "inference"
    INFERENCE_GRAVITON = "inference-graviton"
52+
53+
class Processor(Enum):
    """Processor names an ``ImageSpec`` can target."""

    # ``ImageSpec.retrieve`` uses the member's ``value`` as the key into a
    # version config's ``container_version`` mapping and forwards it to
    # ``_get_image_tag``.
    INF = "inf"
    NEURON = "neuron"
    GPU = "gpu"
    CPU = "cpu"
    TRN = "trn"
2160
2261
2362class ImageSpec :
2463 """ImageSpec class to get image URI for a specific framework version."""
2564
26- def __init__ (
27- self ,
28- framework_name : str ,
29- version : str ,
30- image_scope : Optional [str ] = None ,
31- instance_type : Optional [str ] = None ,
32- py_version : Optional [str ] = None ,
33- region : Optional [str ] = "us-west-2" ,
34- accelerator_type : Optional [str ] = None ,
35- container_version : Optional [str ] = None ,
36- distribution : Optional [dict ] = None ,
37- base_framework_version : Optional [str ] = None ,
38- training_compiler_config : Optional [TrainingCompilerConfig ] = None ,
39- model_id : Optional [str ] = None ,
40- model_version : Optional [str ] = None ,
41- hub_arn : Optional [str ] = None ,
42- tolerate_vulnerable_model : Optional [bool ] = False ,
43- tolerate_deprecated_model : Optional [bool ] = False ,
44- sdk_version : Optional [str ] = None ,
45- inference_tool : Optional [str ] = None ,
46- serverless_inference_config : Optional [ServerlessInferenceConfig ] = None ,
47- config_name : Optional [str ] = None ,
48- sagemaker_session : Optional [Session ] = None ,
49- ):
50- self .framework_name = framework_name
65+ def __init__ (self ,
66+ framework : Framework ,
67+ processor : Optional [Processor ] = Processor .CPU ,
68+ region : Optional [str ] = "us-west-2" ,
69+ version = None ,
70+ py_version = None ,
71+ instance_type = None ,
72+ accelerator_type = None ,
73+ image_scope : ImageScope = ImageScope .TRAINING ,
74+ container_version = None ,
75+ distribution = None ,
76+ base_framework_version = None ,
77+ sdk_version = None ,
78+ inference_tool = None ):
79+ self .framework = framework
80+ self .processor = processor
5181 self .version = version
5282 self .image_scope = image_scope
5383 self .instance_type = instance_type
@@ -57,45 +87,175 @@ def __init__(
5787 self .container_version = container_version
5888 self .distribution = distribution
5989 self .base_framework_version = base_framework_version
60- self .training_compiler_config = training_compiler_config
61- self .model_id = model_id
62- self .model_version = model_version
63- self .hub_arn = hub_arn
64- self .tolerate_vulnerable_model = tolerate_vulnerable_model
65- self .tolerate_deprecated_model = tolerate_deprecated_model
6690 self .sdk_version = sdk_version
6791 self .inference_tool = inference_tool
68- self .serverless_inference_config = serverless_inference_config
69- self .config_name = config_name
70- self .sagemaker_session = sagemaker_session
71-
72- def get_image_uri (
73- self , image_scope : Optional [str ] = None , instance_type : Optional [str ] = None
74- ) -> str :
75- """Get image URI for a specific framework version."""
76-
77- self .image_scope = image_scope or self .image_scope
78- self .instance_type = instance_type or self .instance_type
79- return image_uris .retrieve (
80- framework = self .framework_name ,
81- image_scope = self .image_scope ,
82- instance_type = self .instance_type ,
83- py_version = self .py_version ,
84- region = self .region ,
85- version = self .version ,
86- accelerator_type = self .accelerator_type ,
87- container_version = self .container_version ,
88- distribution = self .distribution ,
89- base_framework_version = self .base_framework_version ,
90- training_compiler_config = self .training_compiler_config ,
91- model_id = self .model_id ,
92- model_version = self .model_version ,
93- hub_arn = self .hub_arn ,
94- tolerate_vulnerable_model = self .tolerate_vulnerable_model ,
95- tolerate_deprecated_model = self .tolerate_deprecated_model ,
96- sdk_version = self .sdk_version ,
97- inference_tool = self .inference_tool ,
98- serverless_inference_config = self .serverless_inference_config ,
99- config_name = self .config_name ,
100- sagemaker_session = self .sagemaker_session ,
101- )
92+
93+ def update_image_spec (self , ** kwargs ):
94+ for key , value in kwargs .items ():
95+ if hasattr (self , key ):
96+ setattr (self , key , value )
97+
98+ def retrieve (self ) -> str :
99+ """Retrieves the ECR URI for the Docker image matching the given arguments.
100+
101+ Ideally this function should not be called directly, rather it should be called from the
102+ fit() function inside framework estimator.
103+
104+ Args:
105+ framework (Framework): The name of the framework or algorithm.
106+ processor (Processor): The name of the processor (CPU, GPU, etc.).
107+ region (str): The AWS region.
108+ version (str): The framework or algorithm version. This is required if there is
109+ more than one supported version for the given framework or algorithm.
110+ py_version (str): The Python version. This is required if there is
111+ more than one supported Python version for the given framework version.
112+ instance_type (str): The SageMaker instance type. For supported types, see
113+ https://aws.amazon.com/sagemaker/pricing. This is required if
114+ there are different images for different processor types.
115+ accelerator_type (str): Elastic Inference accelerator type. For more, see
116+ https://docs.aws.amazon.com/sagemaker/latest/dg/ei.html.
117+ image_scope (str): The image type, i.e. what it is used for.
118+ Valid values: "training", "inference", "inference_graviton", "eia".
119+ If ``accelerator_type`` is set, ``image_scope`` is ignored.
120+ container_version (str): the version of docker image.
121+ Ideally the value of parameter should be created inside the framework.
122+ For custom use, see the list of supported container versions:
123+ https://github.com/aws/deep-learning-containers/blob/master/available_images.md
124+ (default: None).
125+ distribution (dict): A dictionary with information on how to run distributed training
126+ sdk_version (str): the version of python-sdk that will be used in the image retrieval.
127+ (default: None).
128+ inference_tool (str): the tool that will be used to aid in the inference.
129+ Valid values: "neuron, neuronx, None"
130+ (default: None).
131+
132+ Returns:
133+ str: The ECR URI for the corresponding SageMaker Docker image.
134+
135+ Raises:
136+ NotImplementedError: If the scope is not supported.
137+ ValueError: If the combination of arguments specified is not supported or
138+ any PipelineVariable object is passed in.
139+ VulnerableJumpStartModelError: If any of the dependencies required by the script have
140+ known security vulnerabilities.
141+ DeprecatedJumpStartModelError: If the version of the model is deprecated.
142+ """
143+ config = _config_for_framework_and_scope (self .framework .value ,
144+ self .image_scope .value ,
145+ self .accelerator_type )
146+
147+ original_version = self .version
148+ try :
149+ version = _validate_version_and_set_if_needed (self .version , config , self .framework .value )
150+ except ValueError :
151+ version = None
152+ if not version :
153+ version = self ._fetch_latest_version_from_config (config )
154+
155+ version_config = config ["versions" ][_version_for_config (version , config )]
156+
157+ if "huggingface" in self .framework .value :
158+ if version_config .get ("version_aliases" ):
159+ full_base_framework_version = version_config ["version_aliases" ].get (
160+ self .base_framework_version , self .base_framework_version
161+ )
162+ _validate_arg (full_base_framework_version , list (version_config .keys ()), "base framework" )
163+ version_config = version_config .get (full_base_framework_version )
164+
165+ self .py_version = _validate_py_version_and_set_if_needed (self .py_version ,
166+ version_config ,
167+ self .framework .value )
168+ version_config = version_config .get (self .py_version ) or version_config
169+
170+ registry = _registry_from_region (self .region , version_config ["registries" ])
171+ endpoint_data = utils ._botocore_resolver ().construct_endpoint ("ecr" , self .region )
172+ if self .region == "il-central-1" and not endpoint_data :
173+ endpoint_data = {"hostname" : "ecr.{}.amazonaws.com" .format (self .region )}
174+ hostname = endpoint_data ["hostname" ]
175+
176+ repo = version_config ["repository" ]
177+
178+ # if container version is available in .json file, utilize that
179+ if version_config .get ("container_version" ):
180+ self .container_version = version_config ["container_version" ][self .processor .value ]
181+
182+ # Append sdk version in case of trainium instances
183+ if repo in ["pytorch-training-neuron" ]:
184+ if not self .sdk_version :
185+ sdk_version = _get_latest_versions (version_config ["sdk_versions" ])
186+ self .container_version = self .sdk_version + "-" + self .container_version
187+
188+ if self .framework == Framework .HUGGING_FACE :
189+ pt_or_tf_version = (
190+ re .compile ("^(pytorch|tensorflow)(.*)$" ).match (self .base_framework_version ).group (2 )
191+ )
192+ _version = original_version
193+
194+ if repo in [
195+ "huggingface-pytorch-trcomp-training" ,
196+ "huggingface-tensorflow-trcomp-training" ,
197+ ]:
198+ _version = version
199+ if repo in [
200+ "huggingface-pytorch-inference-neuron" ,
201+ "huggingface-pytorch-inference-neuronx" ,
202+ ]:
203+ if not sdk_version :
204+ self .sdk_version = _get_latest_versions (version_config ["sdk_versions" ])
205+ self .container_version = self .sdk_version + "-" + self .container_version
206+ if config .get ("version_aliases" ).get (original_version ):
207+ _version = config .get ("version_aliases" )[original_version ]
208+ if (
209+ config .get ("versions" , {})
210+ .get (_version , {})
211+ .get ("version_aliases" , {})
212+ .get (self .base_framework_version , {})
213+ ):
214+ _base_framework_version = config .get ("versions" )[_version ]["version_aliases" ][
215+ self .base_framework_version
216+ ]
217+ pt_or_tf_version = (
218+ re .compile ("^(pytorch|tensorflow)(.*)$" ).match (_base_framework_version ).group (2 )
219+ )
220+
221+ tag_prefix = f"{ pt_or_tf_version } -transformers{ _version } "
222+ else :
223+ tag_prefix = version_config .get ("tag_prefix" , version )
224+
225+ if repo == f"{ self .framework .value } -inference-graviton" :
226+ self .container_version = f"{ self .container_version } -sagemaker"
227+ _validate_instance_deprecation (self .framework ,
228+ self .instance_type ,
229+ version )
230+
231+ tag = _get_image_tag (
232+ self .container_version ,
233+ self .distribution ,
234+ self .image_scope .value ,
235+ self .framework ,
236+ self .inference_tool ,
237+ self .instance_type ,
238+ self .processor .value ,
239+ self .py_version ,
240+ tag_prefix ,
241+ version )
242+
243+ if tag :
244+ repo += ":{}" .format (tag )
245+
246+ return ECR_URI_TEMPLATE .format (registry = registry , hostname = hostname , repository = repo )
247+
248+ def _fetch_latest_version_from_config (self ,
249+ framework_config : dict ) -> str :
250+ if self .image_scope .value in framework_config :
251+ if image_scope_config := framework_config [self .image_scope .value ]:
252+ if version_aliases := image_scope_config ["version_aliases" ]:
253+ if latest_version := version_aliases ["latest" ]:
254+ return latest_version
255+ versions = list (framework_config ["versions" ].keys ())
256+ top_version = versions [0 ]
257+ bottom_version = versions [- 1 ]
258+
259+ if Version (top_version ) >= Version (bottom_version ):
260+ return top_version
261+ return bottom_version
0 commit comments