26 changes: 26 additions & 0 deletions src/sagemaker/jumpstart/factory/model.py
@@ -543,6 +543,31 @@ def _add_resources_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelInitKwargs:
    return kwargs


def _add_config_name_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelInitKwargs:
    """Sets default config name to the kwargs. Returns full kwargs."""

    specs = verify_model_region_and_return_specs(
        model_id=kwargs.model_id,
        version=kwargs.model_version,
        scope=JumpStartScriptScope.INFERENCE,
        region=kwargs.region,
        tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model,
        tolerate_deprecated_model=kwargs.tolerate_deprecated_model,
        sagemaker_session=kwargs.sagemaker_session,
        model_type=kwargs.model_type,
        config_name=kwargs.config_name,
    )
    if (
        specs.inference_configs
        and specs.inference_configs.get_top_config_from_ranking().config_name
    ):
        kwargs.config_name = (
            kwargs.config_name or specs.inference_configs.get_top_config_from_ranking().config_name
        )

    return kwargs
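The helper above amounts to a simple precedence rule: an explicitly supplied config_name is kept, otherwise the top-ranked inference config (for models that define configs) becomes the default. An illustrative usage sketch, assuming AWS credentials and the model id used in the unit tests further down:

from sagemaker.jumpstart.model import JumpStartModel

# Explicit choice wins over the ranking-derived default.
model = JumpStartModel(
    model_id="pytorch-eqa-bert-base-cased",  # model id borrowed from the unit tests
    config_name="neuron-inference",
)
assert model.config_name == "neuron-inference"

# Omitting config_name falls back to the top-ranked inference config when the
# model defines configs; models without configs keep config_name=None.
model = JumpStartModel(model_id="pytorch-eqa-bert-base-cased")
print(model.config_name)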


def get_deploy_kwargs(
    model_id: str,
    model_version: Optional[str] = None,
@@ -808,5 +833,6 @@ def get_init_kwargs(
    model_init_kwargs = _add_model_package_arn_to_kwargs(kwargs=model_init_kwargs)

    model_init_kwargs = _add_resources_to_kwargs(kwargs=model_init_kwargs)
    model_init_kwargs = _add_config_name_to_kwargs(kwargs=model_init_kwargs)

    return model_init_kwargs
2 changes: 1 addition & 1 deletion src/sagemaker/jumpstart/model.py
@@ -351,7 +351,7 @@ def _validate_model_id_and_type():
        self.tolerate_deprecated_model = model_init_kwargs.tolerate_deprecated_model
        self.region = model_init_kwargs.region
        self.sagemaker_session = model_init_kwargs.sagemaker_session
        self.config_name = config_name
        self.config_name = model_init_kwargs.config_name

        if self.model_type == JumpStartModelType.PROPRIETARY:
            self.log_subscription_warning()
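Since self.config_name is now read from the resolved init kwargs, the attribute also stays in sync with later calls to set_deployment_config, as the updated unit tests below assert. A short illustrative sketch of the expected behavior, assuming the same test model id and a model that ships a "neuron-inference" config:

from sagemaker.jumpstart.model import JumpStartModel

# Hypothetical walk-through mirroring the test expectations further down.
model = JumpStartModel(model_id="pytorch-eqa-bert-base-cased")

# Switching deployment configs updates the resolved config name on the instance.
model.set_deployment_config("neuron-inference")
assert model.config_name == "neuron-inference"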
5 changes: 5 additions & 0 deletions src/sagemaker/jumpstart/types.py
@@ -1076,10 +1076,12 @@ class JumpStartMetadataConfig(JumpStartDataHolderType):
"benchmark_metrics",
"config_components",
"resolved_metadata_config",
"config_name",
]

def __init__(
self,
config_name: str,
base_fields: Dict[str, Any],
config_components: Dict[str, JumpStartConfigComponent],
benchmark_metrics: Dict[str, List[JumpStartBenchmarkStat]],
@@ -1098,6 +1100,7 @@ def __init__(
        self.config_components: Dict[str, JumpStartConfigComponent] = config_components
        self.benchmark_metrics: Dict[str, List[JumpStartBenchmarkStat]] = benchmark_metrics
        self.resolved_metadata_config: Optional[Dict[str, Any]] = None
        self.config_name: Optional[str] = config_name

    def to_json(self) -> Dict[str, Any]:
        """Returns json representation of JumpStartMetadataConfig object."""
@@ -1251,6 +1254,7 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
        inference_configs_dict: Optional[Dict[str, JumpStartMetadataConfig]] = (
            {
                alias: JumpStartMetadataConfig(
                    alias,
                    json_obj,
                    (
                        {
@@ -1303,6 +1307,7 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
        training_configs_dict: Optional[Dict[str, JumpStartMetadataConfig]] = (
            {
                alias: JumpStartMetadataConfig(
                    alias,
                    json_obj,
                    (
                        {
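Threading the alias through as config_name is what lets the factory change above ask a ranked config for its own name. A hypothetical helper, assuming a specs object whose inference_configs were populated by from_json:

def resolve_default_config_name(specs) -> str:
    """Hypothetical sketch: ask the top-ranked inference config for its name."""
    return specs.inference_configs.get_top_config_from_ranking().config_name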
21 changes: 18 additions & 3 deletions tests/unit/sagemaker/jumpstart/model/test_model.py
@@ -1552,6 +1552,8 @@ def test_model_initialization_with_config_name(

        model = JumpStartModel(model_id=model_id, config_name="neuron-inference")

        assert model.config_name == "neuron-inference"

        model.deploy()

        mock_model_deploy.assert_called_once_with(
@@ -1594,6 +1596,8 @@ def test_model_set_deployment_config(

        model = JumpStartModel(model_id=model_id)

        assert model.config_name is None

        model.deploy()

        mock_model_deploy.assert_called_once_with(
@@ -1612,6 +1616,8 @@ def test_model_set_deployment_config(
        mock_get_model_specs.side_effect = get_prototype_spec_with_configs
        model.set_deployment_config("neuron-inference")

        assert model.config_name == "neuron-inference"

        model.deploy()

        mock_model_deploy.assert_called_once_with(
@@ -1654,6 +1660,8 @@ def test_model_unset_deployment_config(

        model = JumpStartModel(model_id=model_id, config_name="neuron-inference")

        assert model.config_name == "neuron-inference"

        model.deploy()

        mock_model_deploy.assert_called_once_with(
@@ -1789,7 +1797,6 @@ def test_model_retrieve_deployment_config(
    ):
        model_id, _ = "pytorch-eqa-bert-base-cased", "*"

        mock_get_init_kwargs.side_effect = lambda *args, **kwargs: get_mock_init_kwargs(model_id)
        mock_verify_model_region_and_return_specs.side_effect = (
            lambda *args, **kwargs: get_base_spec_with_prototype_configs_with_missing_benchmarks()
        )
@@ -1804,15 +1811,23 @@
        )
        mock_model_deploy.return_value = default_predictor

        expected = get_base_deployment_configs()[0]
        config_name = expected.get("DeploymentConfigName")
        mock_get_init_kwargs.side_effect = lambda *args, **kwargs: get_mock_init_kwargs(
            model_id, config_name
        )

        mock_session.return_value = sagemaker_session

        model = JumpStartModel(model_id=model_id)

        expected = get_base_deployment_configs()[0]
        model.set_deployment_config(expected.get("DeploymentConfigName"))
        model.set_deployment_config(config_name)

        self.assertEqual(model.deployment_config, expected)

        mock_get_init_kwargs.reset_mock()
        mock_get_init_kwargs.side_effect = lambda *args, **kwargs: get_mock_init_kwargs(model_id)

        # Unset
        model.set_deployment_config(None)
        self.assertIsNone(model.deployment_config)
9 changes: 6 additions & 3 deletions tests/unit/sagemaker/jumpstart/utils.py
@@ -12,7 +12,7 @@
# language governing permissions and limitations under the License.
from __future__ import absolute_import
import copy
from typing import List, Dict, Any
from typing import List, Dict, Any, Optional
import boto3

from sagemaker.compute_resource_requirements import ResourceRequirements
@@ -237,7 +237,7 @@ def get_base_spec_with_prototype_configs_with_missing_benchmarks(
    copy_inference_configs = copy.deepcopy(INFERENCE_CONFIGS)
    copy_inference_configs["inference_configs"]["neuron-inference"]["benchmark_metrics"] = None

    inference_configs = {**INFERENCE_CONFIGS, **INFERENCE_CONFIG_RANKINGS}
    inference_configs = {**copy_inference_configs, **INFERENCE_CONFIG_RANKINGS}
    training_configs = {**TRAINING_CONFIGS, **TRAINING_CONFIG_RANKINGS}

    spec.update(inference_configs)
@@ -335,7 +335,9 @@ def get_base_deployment_configs_with_acceleration_configs() -> List[Dict[str, Any]]:
    return configs


def get_mock_init_kwargs(model_id) -> JumpStartModelInitKwargs:
def get_mock_init_kwargs(
    model_id: str, config_name: Optional[str] = None
) -> JumpStartModelInitKwargs:
    return JumpStartModelInitKwargs(
        model_id=model_id,
        model_type=JumpStartModelType.OPEN_WEIGHTS,
@@ -344,4 +346,5 @@ def get_mock_init_kwargs(model_id) -> JumpStartModelInitKwargs:
        instance_type=INIT_KWARGS.get("instance_type"),
        env=INIT_KWARGS.get("env"),
        resources=ResourceRequirements(),
        config_name=config_name,
    )
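For reference, an illustrative sketch of calling the widened helper: the first form keeps the old single-argument behavior, the second pins a config on the mocked init kwargs. The model id is the one used throughout these tests:

# Existing call sites keep working; config_name simply defaults to None.
kwargs = get_mock_init_kwargs("pytorch-eqa-bert-base-cased")
assert kwargs.config_name is None

# New call sites can pin the config carried by the mocked init kwargs.
kwargs = get_mock_init_kwargs("pytorch-eqa-bert-base-cased", "neuron-inference")
assert kwargs.config_name == "neuron-inference"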