1717from typing import Optional , Tuple
1818from sagemaker .jumpstart .constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION
1919
20- from sagemaker .jumpstart .utils import get_jumpstart_model_id_version_from_resource_arn
20+ from sagemaker .jumpstart .utils import get_jumpstart_model_info_from_resource_arn
2121from sagemaker .session import Session
2222from sagemaker .utils import aws_partition
2323
@@ -26,7 +26,7 @@ def get_model_info_from_endpoint(
2626 endpoint_name : str ,
2727 inference_component_name : Optional [str ] = None ,
2828 sagemaker_session : Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION ,
29- ) -> Tuple [str , str , Optional [str ], Optional [str ]]:
29+ ) -> Tuple [str , str , Optional [str ], Optional [str ], Optional [ str ] ]:
3030 """Optionally inference component names, return the model ID, version and config name.
3131
3232 Infers the model ID and version based on the resource tags. Returns a tuple of the model ID
@@ -46,7 +46,8 @@ def get_model_info_from_endpoint(
4646 (
4747 model_id ,
4848 model_version ,
49- config_name ,
49+ inference_config_name ,
50+ training_config_name ,
5051 ) = _get_model_info_from_inference_component_endpoint_with_inference_component_name ( # noqa E501 # pylint: disable=c0301
5152 inference_component_name , sagemaker_session
5253 )
@@ -55,17 +56,29 @@ def get_model_info_from_endpoint(
5556 (
5657 model_id ,
5758 model_version ,
58- config_name ,
59+ inference_config_name ,
60+ training_config_name ,
5961 inference_component_name ,
6062 ) = _get_model_info_from_inference_component_endpoint_without_inference_component_name ( # noqa E501 # pylint: disable=c0301
6163 endpoint_name , sagemaker_session
6264 )
6365
6466 else :
65- model_id , model_version , config_name = _get_model_info_from_model_based_endpoint (
67+ (
68+ model_id ,
69+ model_version ,
70+ inference_config_name ,
71+ training_config_name ,
72+ ) = _get_model_info_from_model_based_endpoint (
6673 endpoint_name , inference_component_name , sagemaker_session
6774 )
68- return model_id , model_version , inference_component_name , config_name
75+ return (
76+ model_id ,
77+ model_version ,
78+ inference_component_name ,
79+ inference_config_name ,
80+ training_config_name ,
81+ )
6982
7083
7184def _get_model_info_from_inference_component_endpoint_without_inference_component_name (
@@ -125,9 +138,12 @@ def _get_model_info_from_inference_component_endpoint_with_inference_component_n
125138 f"inference-component/{ inference_component_name } "
126139 )
127140
128- model_id , model_version , config_name = get_jumpstart_model_id_version_from_resource_arn (
129- inference_component_arn , sagemaker_session
130- )
141+ (
142+ model_id ,
143+ model_version ,
144+ inference_config_name ,
145+ training_config_name ,
146+ ) = get_jumpstart_model_info_from_resource_arn (inference_component_arn , sagemaker_session )
131147
132148 if not model_id :
133149 raise ValueError (
@@ -136,14 +152,14 @@ def _get_model_info_from_inference_component_endpoint_with_inference_component_n
136152 "when retrieving default predictor for this inference component."
137153 )
138154
139- return model_id , model_version , config_name
155+ return model_id , model_version , inference_config_name , training_config_name
140156
141157
142158def _get_model_info_from_model_based_endpoint (
143159 endpoint_name : str ,
144160 inference_component_name : Optional [str ],
145161 sagemaker_session : Session ,
146- ) -> Tuple [str , str , Optional [str ]]:
162+ ) -> Tuple [str , str , Optional [str ], Optional [ str ] ]:
147163 """Returns the model ID, version and config name inferred from a model-based endpoint.
148164
149165 Raises:
@@ -163,9 +179,12 @@ def _get_model_info_from_model_based_endpoint(
163179
164180 endpoint_arn = f"arn:{ partition } :sagemaker:{ region } :{ account_id } :endpoint/{ endpoint_name } "
165181
166- model_id , model_version , config_name = get_jumpstart_model_id_version_from_resource_arn (
167- endpoint_arn , sagemaker_session
168- )
182+ (
183+ model_id ,
184+ model_version ,
185+ inference_config_name ,
186+ training_config_name ,
187+ ) = get_jumpstart_model_info_from_resource_arn (endpoint_arn , sagemaker_session )
169188
170189 if not model_id :
171190 raise ValueError (
@@ -174,13 +193,13 @@ def _get_model_info_from_model_based_endpoint(
174193 "predictor for this endpoint."
175194 )
176195
177- return model_id , model_version , config_name
196+ return model_id , model_version , inference_config_name , training_config_name
178197
179198
180199def get_model_info_from_training_job (
181200 training_job_name : str ,
182201 sagemaker_session : Optional [Session ] = DEFAULT_JUMPSTART_SAGEMAKER_SESSION ,
183- ) -> Tuple [str , str , Optional [str ]]:
202+ ) -> Tuple [str , str , Optional [str ], Optional [ str ] ]:
184203 """Returns the model ID and version and config name inferred from a training job.
185204
186205 Raises:
@@ -199,8 +218,9 @@ def get_model_info_from_training_job(
199218 (
200219 model_id ,
201220 inferred_model_version ,
202- config_name ,
203- ) = get_jumpstart_model_id_version_from_resource_arn (training_job_arn , sagemaker_session )
221+ inference_config_name ,
222+ trainig_config_name ,
223+ ) = get_jumpstart_model_info_from_resource_arn (training_job_arn , sagemaker_session )
204224
205225 model_version = inferred_model_version or None
206226
@@ -211,4 +231,4 @@ def get_model_info_from_training_job(
211231 "for this training job."
212232 )
213233
214- return model_id , model_version , config_name
234+ return model_id , model_version , inference_config_name , trainig_config_name
0 commit comments