Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def read_requirements(filename):
# Declare minimal set for installation
required_packages = [
"attrs>=23.1.0,<24",
"boto3>=1.26.131,<2.0",
"boto3>=1.29.6,<2.0",
"cloudpickle==2.2.1",
"google-pasta",
"numpy>=1.9.0,<2.0",
Expand Down
31 changes: 31 additions & 0 deletions src/sagemaker/image_uri_config/djl-deepspeed.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,37 @@
{
"scope": ["inference"],
"versions": {
"0.25.0": {
"registries": {
"af-south-1": "626614931356",
"il-central-1": "780543022126",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
"ap-northeast-3": "364406365360",
"ap-south-1": "763104351884",
"ap-southeast-1": "763104351884",
"ap-southeast-2": "763104351884",
"ap-southeast-3": "907027046896",
"ca-central-1": "763104351884",
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
"eu-north-1": "763104351884",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
"eu-south-1": "692866216735",
"me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
"us-west-1": "763104351884",
"us-west-2": "763104351884"
},
"repository": "djl-inference",
"tag_prefix": "0.25.0-deepspeed0.11.0-cu118"
},
"0.24.0": {
"registries": {
"af-south-1": "626614931356",
Expand Down
31 changes: 31 additions & 0 deletions src/sagemaker/image_uri_config/djl-neuronx.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,37 @@
{
"scope": ["inference"],
"versions": {
"0.25.0": {
"registries": {
"af-south-1": "626614931356",
"il-central-1": "780543022126",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
"ap-northeast-3": "364406365360",
"ap-south-1": "763104351884",
"ap-southeast-1": "763104351884",
"ap-southeast-2": "763104351884",
"ap-southeast-3": "907027046896",
"ca-central-1": "763104351884",
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
"eu-north-1": "763104351884",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
"eu-south-1": "692866216735",
"me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
"us-west-1": "763104351884",
"us-west-2": "763104351884"
},
"repository": "djl-inference",
"tag_prefix": "0.25.0-neuronx-sdk2.15.0"
},
"0.24.0": {
"registries": {
"af-south-1": "626614931356",
Expand Down
36 changes: 36 additions & 0 deletions src/sagemaker/image_uri_config/djl-tensorrtllm.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
{
"scope": ["inference"],
"versions": {
"0.25.0": {
"registries": {
"af-south-1": "626614931356",
"il-central-1": "780543022126",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
"ap-northeast-3": "364406365360",
"ap-south-1": "763104351884",
"ap-southeast-1": "763104351884",
"ap-southeast-2": "763104351884",
"ap-southeast-3": "907027046896",
"ca-central-1": "763104351884",
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
"eu-north-1": "763104351884",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
"eu-south-1": "692866216735",
"me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
"us-west-1": "763104351884",
"us-west-2": "763104351884"
},
"repository": "djl-inference",
"tag_prefix": "0.25.0-tensorrtllm0.5.0-cu122"
}
}
}
89 changes: 85 additions & 4 deletions src/sagemaker/jumpstart/artifacts/environment_variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,13 @@
from sagemaker.jumpstart.constants import (
DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
JUMPSTART_DEFAULT_REGION_NAME,
SAGEMAKER_GATED_MODEL_S3_URI_TRAINING_ENV_VAR_KEY,
)
from sagemaker.jumpstart.enums import (
JumpStartScriptScope,
)
from sagemaker.jumpstart.utils import (
get_jumpstart_gated_content_bucket,
verify_model_region_and_return_specs,
)
from sagemaker.session import Session
Expand Down Expand Up @@ -102,10 +104,89 @@ def _retrieve_default_environment_variables(
elif script == JumpStartScriptScope.TRAINING and getattr(
model_specs, "training_instance_type_variants", None
):
default_environment_variables.update(
model_specs.training_instance_type_variants.get_instance_specific_environment_variables( # noqa E501 # pylint: disable=c0301
instance_type
)
instance_specific_environment_variables = model_specs.training_instance_type_variants.get_instance_specific_environment_variables( # noqa E501 # pylint: disable=c0301
instance_type
)

default_environment_variables.update(instance_specific_environment_variables)

gated_model_env_var: Optional[str] = _retrieve_gated_model_uri_env_var_value(
model_id=model_id,
model_version=model_version,
region=region,
tolerate_vulnerable_model=tolerate_vulnerable_model,
tolerate_deprecated_model=tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
instance_type=instance_type,
)

if gated_model_env_var is not None:
default_environment_variables.update(
{SAGEMAKER_GATED_MODEL_S3_URI_TRAINING_ENV_VAR_KEY: gated_model_env_var}
)

return default_environment_variables


def _retrieve_gated_model_uri_env_var_value(
    model_id: str,
    model_version: str,
    region: Optional[str] = None,
    tolerate_vulnerable_model: bool = False,
    tolerate_deprecated_model: bool = False,
    sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
    instance_type: Optional[str] = None,
) -> Optional[str]:
    """Retrieves the gated model env var URI matching the given arguments.

    Args:
        model_id (str): JumpStart model ID of the JumpStart model for which to
            retrieve the gated model env var URI.
        model_version (str): Version of the JumpStart model for which to retrieve the
            gated model env var URI.
        region (Optional[str]): Region for which to retrieve the gated model env var URI.
            If None, ``JUMPSTART_DEFAULT_REGION_NAME`` is used. (Default: None).
        tolerate_vulnerable_model (bool): True if vulnerable versions of model
            specifications should be tolerated (exception not raised). If False, raises an
            exception if the script used by this version of the model has dependencies with
            known security vulnerabilities. (Default: False).
        tolerate_deprecated_model (bool): True if deprecated versions of model
            specifications should be tolerated (exception not raised). If False, raises
            an exception if the version of the model is deprecated. (Default: False).
        sagemaker_session (sagemaker.session.Session): A SageMaker Session
            object, used for SageMaker interactions. If not
            specified, one is created using the default AWS configuration
            chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
        instance_type (Optional[str]): An instance type to optionally supply in order to get
            the gated model key specific for the instance type. (Default: None).

    Returns:
        Optional[str]: the s3 URI to use for the environment variable, or None if the model
            does not have gated training artifacts.

    Raises:
        ValueError: If the model specs specified are invalid.
    """

    if region is None:
        region = JUMPSTART_DEFAULT_REGION_NAME

    model_specs = verify_model_region_and_return_specs(
        model_id=model_id,
        version=model_version,
        scope=JumpStartScriptScope.TRAINING,
        region=region,
        tolerate_vulnerable_model=tolerate_vulnerable_model,
        tolerate_deprecated_model=tolerate_deprecated_model,
        sagemaker_session=sagemaker_session,
    )

    # Specs without training instance type variants have no gated model key to look up.
    # Return None (the documented "no gated artifacts" result) instead of failing with an
    # AttributeError when the attribute is missing or None.
    instance_type_variants = getattr(model_specs, "training_instance_type_variants", None)
    if instance_type_variants is None:
        return None

    s3_key: Optional[str] = (
        instance_type_variants.get_instance_specific_gated_model_key_env_var_value(instance_type)
    )
    if s3_key is None:
        return None

    return f"s3://{get_jumpstart_gated_content_bucket(region)}/{s3_key}"
27 changes: 19 additions & 8 deletions src/sagemaker/jumpstart/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,10 @@
ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE,
JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY,
JUMPSTART_DEFAULT_REGION_NAME,
JUMPSTART_LOGGER,
MODEL_ID_LIST_WEB_URL,
)
from sagemaker.jumpstart.exceptions import get_wildcard_model_version_msg
from sagemaker.jumpstart.parameters import (
JUMPSTART_DEFAULT_MAX_S3_CACHE_ITEMS,
JUMPSTART_DEFAULT_MAX_SEMANTIC_VERSION_CACHE_ITEMS,
Expand Down Expand Up @@ -172,7 +174,7 @@ def _get_manifest_key_from_model_id_semantic_version(

manifest = self._s3_cache.get(
JumpStartCachedS3ContentKey(JumpStartS3FileType.MANIFEST, self._manifest_file_s3_key)
).formatted_content
)[0].formatted_content

sm_version = utils.get_sagemaker_version()

Expand Down Expand Up @@ -330,7 +332,7 @@ def _retrieval_function(
if file_type == JumpStartS3FileType.SPECS:
formatted_body, _ = self._get_json_file(s3_key, file_type)
model_specs = JumpStartModelSpecs(formatted_body)
utils.emit_logs_based_on_model_specs(model_specs, self.get_region())
utils.emit_logs_based_on_model_specs(model_specs, self.get_region(), self._s3_client)
return JumpStartCachedS3ContentValue(
formatted_content=model_specs
)
Expand All @@ -343,7 +345,7 @@ def get_manifest(self) -> List[JumpStartModelHeader]:

manifest_dict = self._s3_cache.get(
JumpStartCachedS3ContentKey(JumpStartS3FileType.MANIFEST, self._manifest_file_s3_key)
).formatted_content
)[0].formatted_content
manifest = list(manifest_dict.values()) # type: ignore
return manifest

Expand Down Expand Up @@ -403,10 +405,11 @@ def _get_header_impl(

versioned_model_id = self._model_id_semantic_version_manifest_key_cache.get(
JumpStartVersionedModelId(model_id, semantic_version_str)
)
)[0]

manifest = self._s3_cache.get(
JumpStartCachedS3ContentKey(JumpStartS3FileType.MANIFEST, self._manifest_file_s3_key)
).formatted_content
)[0].formatted_content
try:
header = manifest[versioned_model_id] # type: ignore
return header
Expand All @@ -427,10 +430,18 @@ def get_specs(self, model_id: str, semantic_version_str: str) -> JumpStartModelS

header = self.get_header(model_id, semantic_version_str)
spec_key = header.spec_key
specs = self._s3_cache.get(
specs, cache_hit = self._s3_cache.get(
JumpStartCachedS3ContentKey(JumpStartS3FileType.SPECS, spec_key)
).formatted_content
return specs # type: ignore
)
if not cache_hit and "*" in semantic_version_str:
JUMPSTART_LOGGER.warning(
get_wildcard_model_version_msg(
header.model_id,
semantic_version_str,
header.version
)
)
return specs.formatted_content

def clear(self) -> None:
"""Clears the model ID/version and s3 cache."""
Expand Down
1 change: 1 addition & 0 deletions src/sagemaker/jumpstart/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,7 @@
SUPPORTED_JUMPSTART_SCOPES = set(scope.value for scope in JumpStartScriptScope)

ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE = "AWS_JUMPSTART_CONTENT_BUCKET_OVERRIDE"
ENV_VARIABLE_JUMPSTART_GATED_CONTENT_BUCKET_OVERRIDE = "AWS_JUMPSTART_GATED_CONTENT_BUCKET_OVERRIDE"
ENV_VARIABLE_JUMPSTART_MODEL_ARTIFACT_BUCKET_OVERRIDE = "AWS_JUMPSTART_MODEL_BUCKET_OVERRIDE"
ENV_VARIABLE_JUMPSTART_SCRIPT_ARTIFACT_BUCKET_OVERRIDE = "AWS_JUMPSTART_SCRIPT_BUCKET_OVERRIDE"
ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE = (
Expand Down
29 changes: 29 additions & 0 deletions src/sagemaker/jumpstart/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,35 @@
)


# Caveat sentence appended to the customer-facing model version messages below.
_MAJOR_VERSION_WARNING_MSG = (
    "Note that models may have different input/output signatures "
    "after a major version upgrade."
)


def get_wildcard_model_version_msg(
    model_id: str, wildcard_model_version: str, full_model_version: str
) -> str:
    """Builds the customer-facing message for a wildcard model version identifier.

    Args:
        model_id (str): Model ID named in the message.
        wildcard_model_version (str): Wildcard version identifier quoted in the message.
        full_model_version (str): Version the message suggests pinning to.

    Returns:
        str: The message, ending with the major-version caveat sentence.
    """

    sentences = (
        f"Using model '{model_id}' with wildcard version identifier "
        f"'{wildcard_model_version}'.",
        f"You can pin to version '{full_model_version}' for more stable results.",
        _MAJOR_VERSION_WARNING_MSG,
    )
    # Sentences are separated by exactly one space.
    return " ".join(sentences)


def get_old_model_version_msg(
    model_id: str, current_model_version: str, latest_model_version: str
) -> str:
    """Builds the customer-facing message associated with using an old model version.

    Args:
        model_id (str): Model ID named in the message.
        current_model_version (str): Version reported as currently in use.
        latest_model_version (str): Version the message suggests upgrading to.

    Returns:
        str: The message, ending with the major-version caveat sentence.
    """

    sentences = (
        f"Using model '{model_id}' with version '{current_model_version}'.",
        f"You can upgrade to version '{latest_model_version}' "
        "to get the latest model specifications.",
        _MAJOR_VERSION_WARNING_MSG,
    )
    # Sentences are separated by exactly one space.
    return " ".join(sentences)


class JumpStartHyperparametersError(ValueError):
"""Exception raised for bad hyperparameters of a JumpStart model."""

Expand Down
29 changes: 23 additions & 6 deletions src/sagemaker/jumpstart/factory/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from typing import Dict, List, Optional, Union
from sagemaker import (
environment_variables,
hyperparameters as hyperparameters_utils,
image_uris,
instance_types,
Expand Down Expand Up @@ -557,6 +558,18 @@ def _add_env_to_kwargs(
) -> JumpStartEstimatorInitKwargs:
"""Sets environment in kwargs based on default or override, returns full kwargs."""

extra_env_vars = environment_variables.retrieve_default(
model_id=kwargs.model_id,
model_version=kwargs.model_version,
region=kwargs.region,
include_aws_sdk_env_vars=False,
tolerate_deprecated_model=kwargs.tolerate_deprecated_model,
tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model,
sagemaker_session=kwargs.sagemaker_session,
script=JumpStartScriptScope.TRAINING,
instance_type=kwargs.instance_type,
)

model_package_artifact_uri = _retrieve_model_package_model_artifact_s3_uri(
model_id=kwargs.model_id,
model_version=kwargs.model_version,
Expand All @@ -568,12 +581,16 @@ def _add_env_to_kwargs(
)

if model_package_artifact_uri:
if kwargs.environment is None:
kwargs.environment = {}
kwargs.environment = {
**{SAGEMAKER_GATED_MODEL_S3_URI_TRAINING_ENV_VAR_KEY: model_package_artifact_uri},
**kwargs.environment,
}
extra_env_vars.update(
{SAGEMAKER_GATED_MODEL_S3_URI_TRAINING_ENV_VAR_KEY: model_package_artifact_uri}
)

for key, value in extra_env_vars.items():
kwargs.environment = update_dict_if_key_not_present(
kwargs.environment,
key,
value,
)

return kwargs

Expand Down
Loading