Skip to content
24 changes: 24 additions & 0 deletions src/sagemaker/jumpstart/factory/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -543,6 +543,29 @@ def _add_resources_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModel
return kwargs


def _add_config_name_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelInitKwargs:
"""Sets default config name to the kwargs. Returns full kwargs."""

specs = verify_model_region_and_return_specs(
model_id=kwargs.model_id,
version=kwargs.model_version,
scope=JumpStartScriptScope.INFERENCE,
region=kwargs.region,
tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model,
tolerate_deprecated_model=kwargs.tolerate_deprecated_model,
sagemaker_session=kwargs.sagemaker_session,
model_type=kwargs.model_type,
config_name=kwargs.config_name,
)
if (
specs.inference_configs
and specs.inference_configs.get_top_config_from_ranking().resolved_config
):
kwargs.config_name = specs.inference_configs.get_top_config_from_ranking().config_name

return kwargs


def get_deploy_kwargs(
model_id: str,
model_version: Optional[str] = None,
Expand Down Expand Up @@ -808,5 +831,6 @@ def get_init_kwargs(
model_init_kwargs = _add_model_package_arn_to_kwargs(kwargs=model_init_kwargs)

model_init_kwargs = _add_resources_to_kwargs(kwargs=model_init_kwargs)
model_init_kwargs = _add_config_name_to_kwargs(kwargs=model_init_kwargs)

return model_init_kwargs
2 changes: 1 addition & 1 deletion src/sagemaker/jumpstart/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,7 +351,7 @@ def _validate_model_id_and_type():
self.tolerate_deprecated_model = model_init_kwargs.tolerate_deprecated_model
self.region = model_init_kwargs.region
self.sagemaker_session = model_init_kwargs.sagemaker_session
self.config_name = config_name
self.config_name = model_init_kwargs.config_name

if self.model_type == JumpStartModelType.PROPRIETARY:
self.log_subscription_warning()
Expand Down
5 changes: 5 additions & 0 deletions src/sagemaker/jumpstart/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -1076,10 +1076,12 @@ class JumpStartMetadataConfig(JumpStartDataHolderType):
"benchmark_metrics",
"config_components",
"resolved_metadata_config",
"config_name",
]

def __init__(
self,
config_name: str,
base_fields: Dict[str, Any],
config_components: Dict[str, JumpStartConfigComponent],
benchmark_metrics: Dict[str, List[JumpStartBenchmarkStat]],
Expand All @@ -1098,6 +1100,7 @@ def __init__(
self.config_components: Dict[str, JumpStartConfigComponent] = config_components
self.benchmark_metrics: Dict[str, List[JumpStartBenchmarkStat]] = benchmark_metrics
self.resolved_metadata_config: Optional[Dict[str, Any]] = None
self.config_name: Optional[str] = config_name

def to_json(self) -> Dict[str, Any]:
"""Returns json representation of JumpStartMetadataConfig object."""
Expand Down Expand Up @@ -1251,6 +1254,7 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
inference_configs_dict: Optional[Dict[str, JumpStartMetadataConfig]] = (
{
alias: JumpStartMetadataConfig(
alias,
json_obj,
(
{
Expand Down Expand Up @@ -1303,6 +1307,7 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
training_configs_dict: Optional[Dict[str, JumpStartMetadataConfig]] = (
{
alias: JumpStartMetadataConfig(
alias,
json_obj,
(
{
Expand Down
8 changes: 8 additions & 0 deletions tests/unit/sagemaker/jumpstart/model/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1550,6 +1550,8 @@ def test_model_initialization_with_config_name(
mock_session.return_value = sagemaker_session

model = JumpStartModel(model_id=model_id, config_name="neuron-inference")

assert model.config_name == "neuron-inference"

model.deploy()

Expand Down Expand Up @@ -1592,6 +1594,8 @@ def test_model_set_deployment_config(
mock_session.return_value = sagemaker_session

model = JumpStartModel(model_id=model_id)

assert model.config_name == None

model.deploy()

Expand All @@ -1610,6 +1614,8 @@ def test_model_set_deployment_config(
mock_model_deploy.reset_mock()
mock_get_model_specs.side_effect = get_prototype_spec_with_configs
model.set_deployment_config("neuron-inference")

assert model.config_name == "neuron-inference"

model.deploy()

Expand Down Expand Up @@ -1652,6 +1658,8 @@ def test_model_unset_deployment_config(
mock_session.return_value = sagemaker_session

model = JumpStartModel(model_id=model_id, config_name="neuron-inference")

assert model.config_name == "neuron-inference"

model.deploy()

Expand Down