1616import re
1717from enum import Enum
1818from typing import Optional
19+ from packaging .version import Version
1920
2021from sagemaker import utils
21- from sagemaker .image_uris import _validate_version_and_set_if_needed , _version_for_config , \
22- _config_for_framework_and_scope , _validate_py_version_and_set_if_needed , _registry_from_region , ECR_URI_TEMPLATE , \
23- _get_latest_versions , _validate_instance_deprecation , _get_image_tag , _validate_arg
24- from packaging .version import Version
22+ from sagemaker .image_uris import (
23+ _validate_version_and_set_if_needed ,
24+ _version_for_config ,
25+ _config_for_framework_and_scope ,
26+ _validate_py_version_and_set_if_needed ,
27+ _registry_from_region ,
28+ ECR_URI_TEMPLATE ,
29+ _get_latest_versions ,
30+ _validate_instance_deprecation ,
31+ _get_image_tag ,
32+ _validate_arg ,
33+ )
2534
2635DEFAULT_TOLERATE_MODEL = False
2736
2837
2938class Framework (Enum ):
39+ """Framework enum class."""
40+
3041 HUGGING_FACE = "huggingface"
3142 HUGGING_FACE_NEURON = "huggingface-neuron"
3243 HUGGING_FACE_NEURON_X = "huggingface-neuronx"
@@ -46,12 +57,16 @@ class Framework(Enum):
4657
4758
4859class ImageScope (Enum ):
60+ """ImageScope enum class."""
61+
4962 TRAINING = "training"
5063 INFERENCE = "inference"
5164 INFERENCE_GRAVITON = "inference-graviton"
5265
5366
5467class Processor (Enum ):
68+ """Processor enum class."""
69+
5570 INF = "inf"
5671 NEURON = "neuron"
5772 GPU = "gpu"
@@ -60,22 +75,53 @@ class Processor(Enum):
6075
6176
6277class ImageSpec :
63- """ImageSpec class to get image URI for a specific framework version."""
64-
65- def __init__ (self ,
66- framework : Framework ,
67- processor : Optional [Processor ] = Processor .CPU ,
68- region : Optional [str ] = "us-west-2" ,
69- version = None ,
70- py_version = None ,
71- instance_type = None ,
72- accelerator_type = None ,
73- image_scope : ImageScope = ImageScope .TRAINING ,
74- container_version = None ,
75- distribution = None ,
76- base_framework_version = None ,
77- sdk_version = None ,
78- inference_tool = None ):
78+ """ImageSpec class to get image URI for a specific framework version.
79+
80+ Attributes:
81+ framework (Framework): The name of the framework or algorithm.
82+ processor (Processor): The name of the processor (CPU, GPU, etc.).
83+ region (str): The AWS region.
84+ version (str): The framework or algorithm version. This is required if there is
85+ more than one supported version for the given framework or algorithm.
86+ py_version (str): The Python version. This is required if there is
87+ more than one supported Python version for the given framework version.
88+ instance_type (str): The SageMaker instance type. For supported types, see
89+ https://aws.amazon.com/sagemaker/pricing. This is required if
90+ there are different images for different processor types.
91+ accelerator_type (str): Elastic Inference accelerator type. For more, see
92+ https://docs.aws.amazon.com/sagemaker/latest/dg/ei.html.
93+ image_scope (str): The image type, i.e. what it is used for.
94+ Valid values: "training", "inference", "inference_graviton", "eia".
95+ If ``accelerator_type`` is set, ``image_scope`` is ignored.
96+ container_version (str): the version of docker image.
97+ Ideally the value of parameter should be created inside the framework.
98+ For custom use, see the list of supported container versions:
99+ https://github.com/aws/deep-learning-containers/blob/master/available_images.md
100+ (default: None).
101+ distribution (dict): A dictionary with information on how to run distributed training
102+ sdk_version (str): the version of python-sdk that will be used in the image retrieval.
103+ (default: None).
104+ inference_tool (str): the tool that will be used to aid in the inference.
105+ Valid values: "neuron, neuronx, None"
106+ (default: None).
107+ """
108+
109+ def __init__ (
110+ self ,
111+ framework : Framework ,
112+ processor : Optional [Processor ] = Processor .CPU ,
113+ region : Optional [str ] = "us-west-2" ,
114+ version = None ,
115+ py_version = None ,
116+ instance_type = None ,
117+ accelerator_type = None ,
118+ image_scope : ImageScope = ImageScope .TRAINING ,
119+ container_version = None ,
120+ distribution = None ,
121+ base_framework_version = None ,
122+ sdk_version = None ,
123+ inference_tool = None ,
124+ ):
79125 self .framework = framework
80126 self .processor = processor
81127 self .version = version
@@ -91,44 +137,14 @@ def __init__(self,
91137 self .inference_tool = inference_tool
92138
93139 def update_image_spec (self , ** kwargs ):
140+ """Update the ImageSpec object with the given arguments."""
94141 for key , value in kwargs .items ():
95142 if hasattr (self , key ):
96143 setattr (self , key , value )
97144
98145 def retrieve (self ) -> str :
99146 """Retrieves the ECR URI for the Docker image matching the given arguments.
100147
101- Ideally this function should not be called directly, rather it should be called from the
102- fit() function inside framework estimator.
103-
104- Args:
105- framework (Framework): The name of the framework or algorithm.
106- processor (Processor): The name of the processor (CPU, GPU, etc.).
107- region (str): The AWS region.
108- version (str): The framework or algorithm version. This is required if there is
109- more than one supported version for the given framework or algorithm.
110- py_version (str): The Python version. This is required if there is
111- more than one supported Python version for the given framework version.
112- instance_type (str): The SageMaker instance type. For supported types, see
113- https://aws.amazon.com/sagemaker/pricing. This is required if
114- there are different images for different processor types.
115- accelerator_type (str): Elastic Inference accelerator type. For more, see
116- https://docs.aws.amazon.com/sagemaker/latest/dg/ei.html.
117- image_scope (str): The image type, i.e. what it is used for.
118- Valid values: "training", "inference", "inference_graviton", "eia".
119- If ``accelerator_type`` is set, ``image_scope`` is ignored.
120- container_version (str): the version of docker image.
121- Ideally the value of parameter should be created inside the framework.
122- For custom use, see the list of supported container versions:
123- https://github.com/aws/deep-learning-containers/blob/master/available_images.md
124- (default: None).
125- distribution (dict): A dictionary with information on how to run distributed training
126- sdk_version (str): the version of python-sdk that will be used in the image retrieval.
127- (default: None).
128- inference_tool (str): the tool that will be used to aid in the inference.
129- Valid values: "neuron, neuronx, None"
130- (default: None).
131-
132148 Returns:
133149 str: The ECR URI for the corresponding SageMaker Docker image.
134150
@@ -140,13 +156,14 @@ def retrieve(self) -> str:
140156 known security vulnerabilities.
141157 DeprecatedJumpStartModelError: If the version of the model is deprecated.
142158 """
143- config = _config_for_framework_and_scope (self .framework .value ,
144- self .image_scope .value ,
145- self .accelerator_type )
146-
159+ config = _config_for_framework_and_scope (
160+ self .framework .value , self .image_scope .value , self .accelerator_type
161+ )
147162 original_version = self .version
148163 try :
149- version = _validate_version_and_set_if_needed (self .version , config , self .framework .value )
164+ version = _validate_version_and_set_if_needed (
165+ self .version , config , self .framework .value
166+ )
150167 except ValueError :
151168 version = None
152169 if not version :
@@ -159,12 +176,14 @@ def retrieve(self) -> str:
159176 full_base_framework_version = version_config ["version_aliases" ].get (
160177 self .base_framework_version , self .base_framework_version
161178 )
162- _validate_arg (full_base_framework_version , list (version_config .keys ()), "base framework" )
179+ _validate_arg (
180+ full_base_framework_version , list (version_config .keys ()), "base framework"
181+ )
163182 version_config = version_config .get (full_base_framework_version )
164183
165- self .py_version = _validate_py_version_and_set_if_needed (self . py_version ,
166- version_config ,
167- self . framework . value )
184+ self .py_version = _validate_py_version_and_set_if_needed (
185+ self . py_version , version_config , self . framework . value
186+ )
168187 version_config = version_config .get (self .py_version ) or version_config
169188
170189 registry = _registry_from_region (self .region , version_config ["registries" ])
@@ -206,16 +225,18 @@ def retrieve(self) -> str:
206225 if config .get ("version_aliases" ).get (original_version ):
207226 _version = config .get ("version_aliases" )[original_version ]
208227 if (
209- config .get ("versions" , {})
210- .get (_version , {})
211- .get ("version_aliases" , {})
212- .get (self .base_framework_version , {})
228+ config .get ("versions" , {})
229+ .get (_version , {})
230+ .get ("version_aliases" , {})
231+ .get (self .base_framework_version , {})
213232 ):
214233 _base_framework_version = config .get ("versions" )[_version ]["version_aliases" ][
215234 self .base_framework_version
216235 ]
217236 pt_or_tf_version = (
218- re .compile ("^(pytorch|tensorflow)(.*)$" ).match (_base_framework_version ).group (2 )
237+ re .compile ("^(pytorch|tensorflow)(.*)$" )
238+ .match (_base_framework_version )
239+ .group (2 )
219240 )
220241
221242 tag_prefix = f"{ pt_or_tf_version } -transformers{ _version } "
@@ -224,29 +245,28 @@ def retrieve(self) -> str:
224245
225246 if repo == f"{ self .framework .value } -inference-graviton" :
226247 self .container_version = f"{ self .container_version } -sagemaker"
227- _validate_instance_deprecation (self .framework ,
228- self .instance_type ,
229- version )
248+ _validate_instance_deprecation (self .framework , self .instance_type , version )
230249
231250 tag = _get_image_tag (
232251 self .container_version ,
233252 self .distribution ,
234253 self .image_scope .value ,
235- self .framework ,
254+ self .framework . value ,
236255 self .inference_tool ,
237256 self .instance_type ,
238257 self .processor .value ,
239258 self .py_version ,
240259 tag_prefix ,
241- version )
260+ version ,
261+ )
242262
243263 if tag :
244264 repo += ":{}" .format (tag )
245265
246266 return ECR_URI_TEMPLATE .format (registry = registry , hostname = hostname , repository = repo )
247267
248- def _fetch_latest_version_from_config (self ,
249- framework_config : dict ) -> str :
268+ def _fetch_latest_version_from_config (self , framework_config : dict ) -> str :
269+ """Fetches the latest version from the framework config."""
250270 if self .image_scope .value in framework_config :
251271 if image_scope_config := framework_config [self .image_scope .value ]:
252272 if version_aliases := image_scope_config ["version_aliases" ]:
0 commit comments