2222from sagemaker .utils import aws_partition
2323
2424
25- def get_model_id_version_from_endpoint (
25+ def get_model_info_from_endpoint (
2626 endpoint_name : str ,
2727 inference_component_name : Optional [str ] = None ,
2828 sagemaker_session : Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION ,
29- ) -> Tuple [str , str , Optional [str ]]:
30- """Given an endpoint and optionally inference component names, return the model ID and version .
29+ ) -> Tuple [str , str , Optional [str ], Optional [ str ] ]:
30+ """Optionally inference component names, return the model ID, version and config name .
3131
3232 Infers the model ID and version based on the resource tags. Returns a tuple of the model ID
3333 and version. A third string element is included in the tuple for any inferred inference
@@ -46,30 +46,32 @@ def get_model_id_version_from_endpoint(
4646 (
4747 model_id ,
4848 model_version ,
49- ) = _get_model_id_version_from_inference_component_endpoint_with_inference_component_name ( # noqa E501 # pylint: disable=c0301
49+ config_name ,
50+ ) = _get_model_info_from_inference_component_endpoint_with_inference_component_name ( # noqa E501 # pylint: disable=c0301
5051 inference_component_name , sagemaker_session
5152 )
5253
5354 else :
5455 (
5556 model_id ,
5657 model_version ,
58+ config_name ,
5759 inference_component_name ,
58- ) = _get_model_id_version_from_inference_component_endpoint_without_inference_component_name ( # noqa E501 # pylint: disable=c0301
60+ ) = _get_model_info_from_inference_component_endpoint_without_inference_component_name ( # noqa E501 # pylint: disable=c0301
5961 endpoint_name , sagemaker_session
6062 )
6163
6264 else :
63- model_id , model_version = _get_model_id_version_from_model_based_endpoint (
65+ model_id , model_version , config_name = _get_model_info_from_model_based_endpoint (
6466 endpoint_name , inference_component_name , sagemaker_session
6567 )
66- return model_id , model_version , inference_component_name
68+ return model_id , model_version , inference_component_name , config_name
6769
6870
69- def _get_model_id_version_from_inference_component_endpoint_without_inference_component_name (
71+ def _get_model_info_from_inference_component_endpoint_without_inference_component_name (
7072 endpoint_name : str , sagemaker_session : Session
71- ) -> Tuple [str , str , str ]:
72- """Given an endpoint name, derives the model ID, version, and inferred inference component name.
73+ ) -> Tuple [str , str , str , str ]:
74+ """Derives the model ID, version, config name and inferred inference component name.
7375
7476 This function assumes the endpoint corresponds to an inference-component-based endpoint.
7577 An endpoint is inference-component-based if and only if the associated endpoint config
@@ -98,14 +100,14 @@ def _get_model_id_version_from_inference_component_endpoint_without_inference_co
98100 )
99101 inference_component_name = inference_component_names [0 ]
100102 return (
101- * _get_model_id_version_from_inference_component_endpoint_with_inference_component_name (
103+ * _get_model_info_from_inference_component_endpoint_with_inference_component_name (
102104 inference_component_name , sagemaker_session
103105 ),
104106 inference_component_name ,
105107 )
106108
107109
108- def _get_model_id_version_from_inference_component_endpoint_with_inference_component_name (
110+ def _get_model_info_from_inference_component_endpoint_with_inference_component_name (
109111 inference_component_name : str , sagemaker_session : Session
110112):
111113 """Returns the model ID and version inferred from a SageMaker inference component.
@@ -123,7 +125,7 @@ def _get_model_id_version_from_inference_component_endpoint_with_inference_compo
123125 f"inference-component/{ inference_component_name } "
124126 )
125127
126- model_id , model_version = get_jumpstart_model_id_version_from_resource_arn (
128+ model_id , model_version , config_name = get_jumpstart_model_id_version_from_resource_arn (
127129 inference_component_arn , sagemaker_session
128130 )
129131
@@ -134,15 +136,15 @@ def _get_model_id_version_from_inference_component_endpoint_with_inference_compo
134136 "when retrieving default predictor for this inference component."
135137 )
136138
137- return model_id , model_version
139+ return model_id , model_version , config_name
138140
139141
140- def _get_model_id_version_from_model_based_endpoint (
142+ def _get_model_info_from_model_based_endpoint (
141143 endpoint_name : str ,
142144 inference_component_name : Optional [str ],
143145 sagemaker_session : Session ,
144- ) -> Tuple [str , str ]:
145- """Returns the model ID and version inferred from a model-based endpoint.
146+ ) -> Tuple [str , str , Optional [ str ] ]:
147+ """Returns the model ID, version and config name inferred from a model-based endpoint.
146148
147149 Raises:
148150 ValueError: If an inference component name is supplied, or if the endpoint does
@@ -161,7 +163,7 @@ def _get_model_id_version_from_model_based_endpoint(
161163
162164 endpoint_arn = f"arn:{ partition } :sagemaker:{ region } :{ account_id } :endpoint/{ endpoint_name } "
163165
164- model_id , model_version = get_jumpstart_model_id_version_from_resource_arn (
166+ model_id , model_version , config_name = get_jumpstart_model_id_version_from_resource_arn (
165167 endpoint_arn , sagemaker_session
166168 )
167169
@@ -172,14 +174,14 @@ def _get_model_id_version_from_model_based_endpoint(
172174 "predictor for this endpoint."
173175 )
174176
175- return model_id , model_version
177+ return model_id , model_version , config_name
176178
177179
178- def get_model_id_version_from_training_job (
180+ def get_model_info_from_training_job (
179181 training_job_name : str ,
180182 sagemaker_session : Optional [Session ] = DEFAULT_JUMPSTART_SAGEMAKER_SESSION ,
181- ) -> Tuple [str , str ]:
182- """Returns the model ID and version inferred from a training job.
183+ ) -> Tuple [str , str , Optional [ str ] ]:
184+ """Returns the model ID and version and config name inferred from a training job.
183185
184186 Raises:
185187 ValueError: If the training job does not have tags from which the model ID
@@ -194,9 +196,11 @@ def get_model_id_version_from_training_job(
194196 f"arn:{ partition } :sagemaker:{ region } :{ account_id } :training-job/{ training_job_name } "
195197 )
196198
197- model_id , inferred_model_version = get_jumpstart_model_id_version_from_resource_arn (
198- training_job_arn , sagemaker_session
199- )
199+ (
200+ model_id ,
201+ inferred_model_version ,
202+ config_name ,
203+ ) = get_jumpstart_model_id_version_from_resource_arn (training_job_arn , sagemaker_session )
200204
201205 model_version = inferred_model_version or None
202206
@@ -207,4 +211,4 @@ def get_model_id_version_from_training_job(
207211 "for this training job."
208212 )
209213
210- return model_id , model_version
214+ return model_id , model_version , config_name
0 commit comments