@@ -39,6 +39,7 @@ def _retrieve_default_environment_variables(
3939 sagemaker_session : Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION ,
4040 instance_type : Optional [str ] = None ,
4141 script : JumpStartScriptScope = JumpStartScriptScope .INFERENCE ,
42+ config_name : Optional [str ] = None ,
4243) -> Dict [str , str ]:
4344 """Retrieves the inference environment variables for the model matching the given arguments.
4445
@@ -68,6 +69,7 @@ def _retrieve_default_environment_variables(
6869 environment variables specific for the instance type.
6970 script (JumpStartScriptScope): The JumpStart script for which to retrieve
7071 environment variables.
72+ config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
7173 Returns:
7274 dict: the inference environment variables to use for the model.
7375 """
@@ -84,6 +86,7 @@ def _retrieve_default_environment_variables(
8486 tolerate_vulnerable_model = tolerate_vulnerable_model ,
8587 tolerate_deprecated_model = tolerate_deprecated_model ,
8688 sagemaker_session = sagemaker_session ,
89+ config_name = config_name ,
8790 )
8891
8992 default_environment_variables : Dict [str , str ] = {}
@@ -121,6 +124,7 @@ def _retrieve_default_environment_variables(
121124 tolerate_deprecated_model = tolerate_deprecated_model ,
122125 sagemaker_session = sagemaker_session ,
123126 instance_type = instance_type ,
127+ config_name = config_name ,
124128 )
125129 )
126130
@@ -167,6 +171,7 @@ def _retrieve_gated_model_uri_env_var_value(
167171 tolerate_deprecated_model : bool = False ,
168172 sagemaker_session : Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION ,
169173 instance_type : Optional [str ] = None ,
174+ config_name : Optional [str ] = None ,
170175) -> Optional [str ]:
171176 """Retrieves the gated model env var URI matching the given arguments.
172177
@@ -190,6 +195,7 @@ def _retrieve_gated_model_uri_env_var_value(
190195 chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
191196 instance_type (str): An instance type to optionally supply in order to get
192197 environment variables specific for the instance type.
198+ config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
193199
194200 Returns:
195201 Optional[str]: the s3 URI to use for the environment variable, or None if the model does not
@@ -211,6 +217,7 @@ def _retrieve_gated_model_uri_env_var_value(
211217 tolerate_vulnerable_model = tolerate_vulnerable_model ,
212218 tolerate_deprecated_model = tolerate_deprecated_model ,
213219 sagemaker_session = sagemaker_session ,
220+ config_name = config_name ,
214221 )
215222
216223 s3_key : Optional [str ] = (
0 commit comments