Merged
160 changes: 85 additions & 75 deletions src/sagemaker/jumpstart/model.py
@@ -41,18 +41,17 @@
 from sagemaker.jumpstart.types import (
     JumpStartSerializablePayload,
     DeploymentConfigMetadata,
-    JumpStartBenchmarkStat,
-    JumpStartMetadataConfig,
 )
 from sagemaker.jumpstart.utils import (
     validate_model_id_and_get_type,
     verify_model_region_and_return_specs,
     get_jumpstart_configs,
     get_metrics_from_deployment_configs,
+    add_instance_rate_stats_to_benchmark_metrics,
 )
 from sagemaker.jumpstart.constants import JUMPSTART_LOGGER
 from sagemaker.jumpstart.enums import JumpStartModelType
-from sagemaker.utils import stringify_object, format_tags, Tags, get_instance_rate_per_hour
+from sagemaker.utils import stringify_object, format_tags, Tags
 from sagemaker.model import (
     Model,
     ModelPackage,
@@ -361,17 +360,13 @@ def _validate_model_id_and_type():
         self.model_package_arn = model_init_kwargs.model_package_arn
         self.init_kwargs = model_init_kwargs.to_kwargs_dict(False)
 
-        metadata_configs = get_jumpstart_configs(
+        self._metadata_configs = get_jumpstart_configs(
             region=self.region,
             model_id=self.model_id,
             model_version=self.model_version,
             sagemaker_session=self.sagemaker_session,
             model_type=self.model_type,
         )
-        self._deployment_configs = [
-            self._convert_to_deployment_config_metadata(config_name, config)
-            for config_name, config in metadata_configs.items()
-        ]
 
     def log_subscription_warning(self) -> None:
         """Log message prompting the customer to subscribe to the proprietary model."""
@@ -449,33 +444,46 @@ def set_deployment_config(self, config_name: str, instance_type: str) -> None:

     @property
     def deployment_config(self) -> Optional[Dict[str, Any]]:
-        """The deployment config that will be applied to the model.
+        """The deployment config that will be applied to this model.
 
         Returns:
-            Optional[Dict[str, Any]]: Deployment config that will be applied to the model.
+            Optional[Dict[str, Any]]: Deployment config.
         """
-        return self._retrieve_selected_deployment_config(self.config_name)
+        deployment_config = self._retrieve_selected_deployment_config(
+            self.config_name, self.instance_type
+        )
+        return deployment_config.to_json() if deployment_config is not None else None
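
A quick usage sketch of the reworked property (the model id and config/instance names are illustrative, not taken from this PR):

```python
from sagemaker.jumpstart.model import JumpStartModel

# Hypothetical model id; any JumpStart model that ships deployment
# configs behaves the same way.
model = JumpStartModel(model_id="meta-textgeneration-llama-2-7b")
model.set_deployment_config(config_name="tgi", instance_type="ml.g5.2xlarge")

config = model.deployment_config  # now a plain dict via to_json(), or None
if config is not None:
    print(config["DeploymentConfigName"])
    print(config["DeploymentArgs"]["DefaultInstanceType"])
```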

     @property
     def benchmark_metrics(self) -> pd.DataFrame:
-        """Benchmark Metrics for deployment configs
+        """Benchmark metrics for deployment configs.
 
         Returns:
-            Metrics: Pandas DataFrame object.
+            pd.DataFrame: Benchmark metrics.
         """
-        return pd.DataFrame(self._get_benchmarks_data(self.config_name))
+        benchmark_metrics_data = self._get_deployment_configs_benchmarks_data(
+            self.config_name, self.instance_type
+        )
+        keys = list(benchmark_metrics_data.keys())
+        df = pd.DataFrame(benchmark_metrics_data).sort_values(by=[keys[0], keys[1]])
+        return df
 
     def display_benchmark_metrics(self) -> None:
-        """Display Benchmark Metrics for deployment configs."""
-        print(self.benchmark_metrics.to_markdown())
+        """Display deployment configs benchmark metrics."""
+        print(self.benchmark_metrics.to_markdown(index=False))
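
The property now sorts rows by the first two metric columns, and `display_benchmark_metrics` drops the index. A standalone pandas sketch of just that behavior (column names invented for illustration):

```python
import pandas as pd

# Assumed shape: column name -> list of values, as produced by
# get_metrics_from_deployment_configs.
data = {
    "Config Name": ["tgi", "neuron", "tgi"],
    "Instance Type": ["ml.g5.2xlarge", "ml.inf2.xlarge", "ml.g5.xlarge"],
    "Latency (ms)": [42, 51, 38],
}
keys = list(data.keys())
# Sort by the first two columns: group rows by config, then instance type.
df = pd.DataFrame(data).sort_values(by=[keys[0], keys[1]])
# index=False omits the positional index; to_markdown needs `tabulate` installed.
print(df.to_markdown(index=False))
```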

     def list_deployment_configs(self) -> List[Dict[str, Any]]:
-        """List deployment configs for ``This`` model.
+        """List deployment configs for this model.
 
         Returns:
             List[Dict[str, Any]]: A list of deployment configs.
         """
-        return self._deployment_configs
+        return [
+            deployment_config.to_json()
+            for deployment_config in self._get_deployment_configs(
+                self.config_name, self.instance_type
+            )
+        ]
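
Callers now receive plain dicts; a hedged sketch of iterating them, reusing `model` from the sketch above (the pascal-case keys follow the conversion in types.py, and the instance-type fields assume the config carried a resolved_config):

```python
for config in model.list_deployment_configs():
    args = config["DeploymentArgs"]
    print(
        config["DeploymentConfigName"],
        args["DefaultInstanceType"],
        args["SupportedInstanceTypes"],
    )
```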

     def _create_sagemaker_model(
         self,
@@ -866,92 +874,94 @@ def register_deploy_wrapper(*args, **kwargs):
         return model_package
 
     @lru_cache
-    def _get_benchmarks_data(self, config_name: str) -> Dict[str, List[str]]:
+    def _get_deployment_configs_benchmarks_data(
+        self, config_name: str, instance_type: str
+    ) -> Dict[str, Any]:
         """Deployment configs benchmark metrics.
 
         Args:
-            config_name (str): The name of the selected deployment config.
+            config_name (str): Name of the selected deployment config.
+            instance_type (str): The selected instance type.
         Returns:
-            Dict[str, List[str]]: Deployment config benchmark data.
+            Dict[str, Any]: Deployment config benchmark data.
         """
         return get_metrics_from_deployment_configs(
-            self._deployment_configs,
-            config_name,
+            self._get_deployment_configs(config_name, instance_type)
         )
 
     @lru_cache
-    def _retrieve_selected_deployment_config(self, config_name: str) -> Optional[Dict[str, Any]]:
-        """Retrieve the deployment config to apply to the model.
+    def _retrieve_selected_deployment_config(
+        self, config_name: str, instance_type: str
+    ) -> Optional[DeploymentConfigMetadata]:
+        """Retrieve the deployment config to apply to this model.
 
         Args:
             config_name (str): The name of the deployment config to retrieve.
+            instance_type (str): The instance type of the deployment config to retrieve.
         Returns:
-            Optional[Dict[str, Any]]: The retrieved deployment config.
+            Optional[DeploymentConfigMetadata]: The retrieved deployment config.
         """
         if config_name is None:
             return None
 
-        for deployment_config in self._deployment_configs:
-            if deployment_config.get("DeploymentConfigName") == config_name:
+        for deployment_config in self._get_deployment_configs(config_name, instance_type):
+            if deployment_config.deployment_config_name == config_name:
                 return deployment_config
         return None
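
Both helpers are memoized with @lru_cache, so repeated reads of `deployment_config` reuse one lookup per (config name, instance type) pair. A minimal standalone illustration of that caching behavior (not SDK code):

```python
from functools import lru_cache

class Demo:
    calls = 0

    @lru_cache
    def lookup(self, config_name: str, instance_type: str) -> tuple:
        # Cached per (self, config_name, instance_type); args must be hashable.
        Demo.calls += 1
        return (config_name, instance_type)

d = Demo()
d.lookup("tgi", "ml.g5.2xlarge")
d.lookup("tgi", "ml.g5.2xlarge")  # served from the cache
assert Demo.calls == 1
```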

-    def _convert_to_deployment_config_metadata(
-        self, config_name: str, metadata_config: JumpStartMetadataConfig
-    ) -> Dict[str, Any]:
-        """Retrieve deployment config for config name.
+    @lru_cache
+    def _get_deployment_configs(
+        self, selected_config_name: str, selected_instance_type: str
+    ) -> List[DeploymentConfigMetadata]:
+        """Retrieve deployment configs metadata.
 
         Args:
-            config_name (str): Name of deployment config.
-            metadata_config (JumpStartMetadataConfig): Metadata config for deployment config.
-        Returns:
-            A deployment metadata config for config name (dict[str, Any]).
+            selected_config_name (str): The name of the selected deployment config.
+            selected_instance_type (str): The selected instance type.
         """
-        default_inference_instance_type = metadata_config.resolved_config.get(
-            "default_inference_instance_type"
-        )
-
-        benchmark_metrics = (
-            metadata_config.benchmark_metrics.get(default_inference_instance_type)
-            if metadata_config.benchmark_metrics is not None
-            else None
-        )
-
-        should_fetch_instance_rate_metric = True
-        if benchmark_metrics is not None:
-            for benchmark_metric in benchmark_metrics:
-                if benchmark_metric.name.lower() == "instance rate":
-                    should_fetch_instance_rate_metric = False
-                    break
-
-        if should_fetch_instance_rate_metric:
-            instance_rate = get_instance_rate_per_hour(
-                instance_type=default_inference_instance_type, region=self.region
-            )
-            if instance_rate is not None:
-                instance_rate_metric = JumpStartBenchmarkStat(instance_rate)
-
-                if benchmark_metrics is None:
-                    benchmark_metrics = [instance_rate_metric]
-                else:
-                    benchmark_metrics.append(instance_rate_metric)
-
-        init_kwargs = get_init_kwargs(
-            model_id=self.model_id,
-            instance_type=default_inference_instance_type,
-            sagemaker_session=self.sagemaker_session,
-        )
-        deploy_kwargs = get_deploy_kwargs(
-            model_id=self.model_id,
-            instance_type=default_inference_instance_type,
-            sagemaker_session=self.sagemaker_session,
-        )
-
-        deployment_config_metadata = DeploymentConfigMetadata(
-            config_name, benchmark_metrics, init_kwargs, deploy_kwargs
-        )
-
-        return deployment_config_metadata.to_json()
+        deployment_configs = []
+        if self._metadata_configs is None:
+            return deployment_configs
+
+        err = None
+        for config_name, metadata_config in self._metadata_configs.items():
+            if err is None or "is not authorized to perform: pricing:GetProducts" not in err:
+                err, metadata_config.benchmark_metrics = (
+                    add_instance_rate_stats_to_benchmark_metrics(
+                        self.region, metadata_config.benchmark_metrics
+                    )
+                )
+
+            resolved_config = metadata_config.resolved_config
+            if selected_config_name == config_name:
+                instance_type_to_use = selected_instance_type
+            else:
+                instance_type_to_use = resolved_config.get("default_inference_instance_type")
+
+            init_kwargs = get_init_kwargs(
+                model_id=self.model_id,
+                instance_type=instance_type_to_use,
+                sagemaker_session=self.sagemaker_session,
+            )
+            deploy_kwargs = get_deploy_kwargs(
+                model_id=self.model_id,
+                instance_type=instance_type_to_use,
+                sagemaker_session=self.sagemaker_session,
+            )
+            deployment_config_metadata = DeploymentConfigMetadata(
+                config_name,
+                metadata_config.benchmark_metrics,
+                resolved_config,
+                init_kwargs,
+                deploy_kwargs,
+            )
+            deployment_configs.append(deployment_config_metadata)
+
+        if err is not None and "is not authorized to perform: pricing:GetProducts" in err:
+            error_message = "Instance rate metrics will be omitted. Reason: %s"
+            JUMPSTART_LOGGER.warning(error_message, err)
+
+        return deployment_configs

     def __str__(self) -> str:
         """Overriding str(*) method to make more human-readable."""
72 changes: 45 additions & 27 deletions src/sagemaker/jumpstart/types.py
@@ -2235,29 +2235,37 @@ def to_json(self) -> Dict[str, Any]:
             if hasattr(self, att):
                 cur_val = getattr(self, att)
                 att = self._convert_to_pascal_case(att)
-                if issubclass(type(cur_val), JumpStartDataHolderType):
-                    json_obj[att] = cur_val.to_json()
-                elif isinstance(cur_val, list):
-                    json_obj[att] = []
-                    for obj in cur_val:
-                        if issubclass(type(obj), JumpStartDataHolderType):
-                            json_obj[att].append(obj.to_json())
-                        else:
-                            json_obj[att].append(obj)
-                elif isinstance(cur_val, dict):
-                    json_obj[att] = {}
-                    for key, val in cur_val.items():
-                        if issubclass(type(val), JumpStartDataHolderType):
-                            json_obj[att][self._convert_to_pascal_case(key)] = val.to_json()
-                        else:
-                            json_obj[att][key] = val
-                else:
-                    json_obj[att] = cur_val
+                json_obj[att] = self._val_to_json(cur_val)
         return json_obj

+    def _val_to_json(self, val: Any) -> Any:
+        """Converts the given value to JSON.
+
+        Args:
+            val (Any): The value to convert.
+        Returns:
+            Any: The converted JSON value.
+        """
+        if issubclass(type(val), JumpStartDataHolderType):
+            return val.to_json()
+        if isinstance(val, list):
+            list_obj = []
+            for obj in val:
+                list_obj.append(self._val_to_json(obj))
+            return list_obj
+        if isinstance(val, dict):
+            dict_obj = {}
+            for k, v in val.items():
+                if isinstance(v, JumpStartDataHolderType):
+                    dict_obj[self._convert_to_pascal_case(k)] = self._val_to_json(v)
+                else:
+                    dict_obj[k] = self._val_to_json(v)
+            return dict_obj
+        return val
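
The refactor collapses to_json's branching into one recursive converter. A self-contained mimic of its behavior on nested values (`Holder` and `pascal` stand in for JumpStartDataHolderType and _convert_to_pascal_case):

```python
from typing import Any

class Holder:
    """Stand-in for JumpStartDataHolderType."""

    def to_json(self) -> dict:
        return {"Value": 1}

def pascal(snake: str) -> str:
    # Simplified stand-in for _convert_to_pascal_case.
    return "".join(part.capitalize() for part in snake.split("_"))

def val_to_json(val: Any) -> Any:
    if isinstance(val, Holder):
        return val.to_json()
    if isinstance(val, list):
        return [val_to_json(v) for v in val]
    if isinstance(val, dict):
        # Keys are pascal-cased only when the value is a data holder,
        # matching the dict branch above.
        return {
            (pascal(k) if isinstance(v, Holder) else k): val_to_json(v)
            for k, v in val.items()
        }
    return val

assert val_to_json([Holder(), {"inner_key": Holder(), "n": 2}]) == [
    {"Value": 1},
    {"InnerKey": {"Value": 1}, "n": 2},
]
```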


 class DeploymentArgs(BaseDeploymentConfigDataHolder):
-    """Dataclass representing a Deployment Config."""
+    """Dataclass representing deployment args."""
 
     __slots__ = [
         "image_uri",
@@ -2270,9 +2278,12 @@ class DeploymentArgs(BaseDeploymentConfigDataHolder):
     ]
 
     def __init__(
-        self, init_kwargs: JumpStartModelInitKwargs, deploy_kwargs: JumpStartModelDeployKwargs
+        self,
+        init_kwargs: Optional[JumpStartModelInitKwargs] = None,
+        deploy_kwargs: Optional[JumpStartModelDeployKwargs] = None,
+        resolved_config: Optional[Dict[str, Any]] = None,
     ):
-        """Instantiates DeploymentConfig object."""
+        """Instantiates DeploymentArgs object."""
         if init_kwargs is not None:
             self.image_uri = init_kwargs.image_uri
             self.model_data = init_kwargs.model_data
@@ -2287,6 +2298,11 @@ def __init__(
             self.container_startup_health_check_timeout = (
                 deploy_kwargs.container_startup_health_check_timeout
             )
+        if resolved_config is not None:
+            self.default_instance_type = resolved_config.get("default_inference_instance_type")
+            self.supported_instance_types = resolved_config.get(
+                "supported_inference_instance_types"
+            )


 class DeploymentConfigMetadata(BaseDeploymentConfigDataHolder):
@@ -2301,13 +2317,15 @@ class DeploymentConfigMetadata(BaseDeploymentConfigDataHolder):

     def __init__(
         self,
-        config_name: str,
-        benchmark_metrics: List[JumpStartBenchmarkStat],
-        init_kwargs: JumpStartModelInitKwargs,
-        deploy_kwargs: JumpStartModelDeployKwargs,
+        config_name: Optional[str] = None,
+        benchmark_metrics: Optional[Dict[str, List[JumpStartBenchmarkStat]]] = None,
+        resolved_config: Optional[Dict[str, Any]] = None,
+        init_kwargs: Optional[JumpStartModelInitKwargs] = None,
+        deploy_kwargs: Optional[JumpStartModelDeployKwargs] = None,
     ):
         """Instantiates DeploymentConfigMetadata object."""
         self.deployment_config_name = config_name
-        self.deployment_args = DeploymentArgs(init_kwargs, deploy_kwargs)
-        self.acceleration_configs = None
+        self.deployment_args = DeploymentArgs(init_kwargs, deploy_kwargs, resolved_config)
         self.benchmark_metrics = benchmark_metrics
+        if resolved_config is not None:
+            self.acceleration_configs = resolved_config.get("acceleration_configs")
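
With every argument now optional, the holder can be built from whatever metadata is available. A hedged construction sketch (values invented; the resolved_config keys mirror the ones read above, and unset slots are skipped by to_json's hasattr guard):

```python
from sagemaker.jumpstart.types import DeploymentConfigMetadata

metadata = DeploymentConfigMetadata(
    config_name="tgi",
    benchmark_metrics=None,  # normally Dict[str, List[JumpStartBenchmarkStat]]
    resolved_config={
        "default_inference_instance_type": "ml.g5.2xlarge",
        "supported_inference_instance_types": ["ml.g5.2xlarge", "ml.g5.12xlarge"],
        "acceleration_configs": None,
    },
)
# Only the resolved_config-derived fields appear under DeploymentArgs,
# since init_kwargs and deploy_kwargs were left as None.
print(metadata.to_json()["DeploymentArgs"]["DefaultInstanceType"])  # ml.g5.2xlarge
```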