diff --git a/setup.py b/setup.py index cbdc5cdfc6..551431fbde 100644 --- a/setup.py +++ b/setup.py @@ -48,7 +48,7 @@ def read_requirements(filename): # Declare minimal set for installation required_packages = [ "attrs>=23.1.0,<24", - "boto3>=1.26.131,<2.0", + "boto3>=1.29.6,<2.0", "cloudpickle==2.2.1", "google-pasta", "numpy>=1.9.0,<2.0", diff --git a/src/sagemaker/image_uri_config/djl-deepspeed.json b/src/sagemaker/image_uri_config/djl-deepspeed.json index b78ffaa3eb..0172953da3 100644 --- a/src/sagemaker/image_uri_config/djl-deepspeed.json +++ b/src/sagemaker/image_uri_config/djl-deepspeed.json @@ -1,6 +1,37 @@ { "scope": ["inference"], "versions": { + "0.25.0": { + "registries": { + "af-south-1": "626614931356", + "il-central-1": "780543022126", + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ca-central-1": "763104351884", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-north-1": "763104351884", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "eu-south-1": "692866216735", + "me-south-1": "217643126080", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-west-1": "763104351884", + "us-west-2": "763104351884" + }, + "repository": "djl-inference", + "tag_prefix": "0.25.0-deepspeed0.11.0-cu118" + }, "0.24.0": { "registries": { "af-south-1": "626614931356", diff --git a/src/sagemaker/image_uri_config/djl-neuronx.json b/src/sagemaker/image_uri_config/djl-neuronx.json index ae3fd4820c..26567bfb16 100644 --- a/src/sagemaker/image_uri_config/djl-neuronx.json +++ b/src/sagemaker/image_uri_config/djl-neuronx.json @@ -1,6 +1,37 @@ { "scope": ["inference"], "versions": { + "0.25.0": { + "registries": { + "af-south-1": "626614931356", + "il-central-1": "780543022126", + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ca-central-1": "763104351884", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-north-1": "763104351884", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "eu-south-1": "692866216735", + "me-south-1": "217643126080", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-west-1": "763104351884", + "us-west-2": "763104351884" + }, + "repository": "djl-inference", + "tag_prefix": "0.25.0-neuronx-sdk2.15.0" + }, "0.24.0": { "registries": { "af-south-1": "626614931356", diff --git a/src/sagemaker/image_uri_config/djl-tensorrtllm.json b/src/sagemaker/image_uri_config/djl-tensorrtllm.json new file mode 100644 index 0000000000..5cb7bcfe38 --- /dev/null +++ b/src/sagemaker/image_uri_config/djl-tensorrtllm.json @@ -0,0 +1,36 @@ +{ + "scope": ["inference"], + "versions": { + "0.25.0": { + "registries": { + "af-south-1": "626614931356", + "il-central-1": "780543022126", + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ca-central-1": "763104351884", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-north-1": "763104351884", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "eu-south-1": "692866216735", + "me-south-1": "217643126080", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-west-1": "763104351884", + "us-west-2": "763104351884" + }, + "repository": "djl-inference", + "tag_prefix": "0.25.0-tensorrtllm0.5.0-cu122" + } + } +} diff --git a/src/sagemaker/jumpstart/artifacts/environment_variables.py b/src/sagemaker/jumpstart/artifacts/environment_variables.py index be210bf7a8..0e666e4c14 100644 --- a/src/sagemaker/jumpstart/artifacts/environment_variables.py +++ b/src/sagemaker/jumpstart/artifacts/environment_variables.py @@ -16,11 +16,13 @@ from sagemaker.jumpstart.constants import ( DEFAULT_JUMPSTART_SAGEMAKER_SESSION, JUMPSTART_DEFAULT_REGION_NAME, + SAGEMAKER_GATED_MODEL_S3_URI_TRAINING_ENV_VAR_KEY, ) from sagemaker.jumpstart.enums import ( JumpStartScriptScope, ) from sagemaker.jumpstart.utils import ( + get_jumpstart_gated_content_bucket, verify_model_region_and_return_specs, ) from sagemaker.session import Session @@ -102,10 +104,89 @@ def _retrieve_default_environment_variables( elif script == JumpStartScriptScope.TRAINING and getattr( model_specs, "training_instance_type_variants", None ): - default_environment_variables.update( - model_specs.training_instance_type_variants.get_instance_specific_environment_variables( # noqa E501 # pylint: disable=c0301 - instance_type - ) + instance_specific_environment_variables = model_specs.training_instance_type_variants.get_instance_specific_environment_variables( # noqa E501 # pylint: disable=c0301 + instance_type ) + default_environment_variables.update(instance_specific_environment_variables) + + gated_model_env_var: Optional[str] = _retrieve_gated_model_uri_env_var_value( + model_id=model_id, + model_version=model_version, + region=region, + tolerate_vulnerable_model=tolerate_vulnerable_model, + tolerate_deprecated_model=tolerate_deprecated_model, + sagemaker_session=sagemaker_session, + instance_type=instance_type, + ) + + if gated_model_env_var is not None: + default_environment_variables.update( + {SAGEMAKER_GATED_MODEL_S3_URI_TRAINING_ENV_VAR_KEY: gated_model_env_var} + ) + return default_environment_variables + + +def _retrieve_gated_model_uri_env_var_value( + model_id: str, + model_version: str, + region: Optional[str] = None, + tolerate_vulnerable_model: bool = False, + tolerate_deprecated_model: bool = False, + sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + instance_type: Optional[str] = None, +) -> Optional[str]: + """Retrieves the gated model env var URI matching the given arguments. + + Args: + model_id (str): JumpStart model ID of the JumpStart model for which to + retrieve the gated model env var URI. + model_version (str): Version of the JumpStart model for which to retrieve the + gated model env var URI. + region (Optional[str]): Region for which to retrieve the gated model env var URI. + (Default: None). + tolerate_vulnerable_model (bool): True if vulnerable versions of model + specifications should be tolerated (exception not raised). If False, raises an + exception if the script used by this version of the model has dependencies with known + security vulnerabilities. (Default: False). + tolerate_deprecated_model (bool): True if deprecated versions of model + specifications should be tolerated (exception not raised). If False, raises + an exception if the version of the model is deprecated. (Default: False). + sagemaker_session (sagemaker.session.Session): A SageMaker Session + object, used for SageMaker interactions. If not + specified, one is created using the default AWS configuration + chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + instance_type (str): An instance type to optionally supply in order to get + environment variables specific for the instance type. + + Returns: + Optional[str]: the s3 URI to use for the environment variable, or None if the model does not + have gated training artifacts. + + Raises: + ValueError: If the model specs specified are invalid. + """ + + if region is None: + region = JUMPSTART_DEFAULT_REGION_NAME + + model_specs = verify_model_region_and_return_specs( + model_id=model_id, + version=model_version, + scope=JumpStartScriptScope.TRAINING, + region=region, + tolerate_vulnerable_model=tolerate_vulnerable_model, + tolerate_deprecated_model=tolerate_deprecated_model, + sagemaker_session=sagemaker_session, + ) + + s3_key: Optional[ + str + ] = model_specs.training_instance_type_variants.get_instance_specific_gated_model_key_env_var_value( # noqa E501 # pylint: disable=c0301 + instance_type + ) + if s3_key is None: + return None + + return f"s3://{get_jumpstart_gated_content_bucket(region)}/{s3_key}" diff --git a/src/sagemaker/jumpstart/cache.py b/src/sagemaker/jumpstart/cache.py index 25206645f2..e26d588167 100644 --- a/src/sagemaker/jumpstart/cache.py +++ b/src/sagemaker/jumpstart/cache.py @@ -26,8 +26,10 @@ ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE, JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY, JUMPSTART_DEFAULT_REGION_NAME, + JUMPSTART_LOGGER, MODEL_ID_LIST_WEB_URL, ) +from sagemaker.jumpstart.exceptions import get_wildcard_model_version_msg from sagemaker.jumpstart.parameters import ( JUMPSTART_DEFAULT_MAX_S3_CACHE_ITEMS, JUMPSTART_DEFAULT_MAX_SEMANTIC_VERSION_CACHE_ITEMS, @@ -172,7 +174,7 @@ def _get_manifest_key_from_model_id_semantic_version( manifest = self._s3_cache.get( JumpStartCachedS3ContentKey(JumpStartS3FileType.MANIFEST, self._manifest_file_s3_key) - ).formatted_content + )[0].formatted_content sm_version = utils.get_sagemaker_version() @@ -330,7 +332,7 @@ def _retrieval_function( if file_type == JumpStartS3FileType.SPECS: formatted_body, _ = self._get_json_file(s3_key, file_type) model_specs = JumpStartModelSpecs(formatted_body) - utils.emit_logs_based_on_model_specs(model_specs, self.get_region()) + utils.emit_logs_based_on_model_specs(model_specs, self.get_region(), self._s3_client) return JumpStartCachedS3ContentValue( formatted_content=model_specs ) @@ -343,7 +345,7 @@ def get_manifest(self) -> List[JumpStartModelHeader]: manifest_dict = self._s3_cache.get( JumpStartCachedS3ContentKey(JumpStartS3FileType.MANIFEST, self._manifest_file_s3_key) - ).formatted_content + )[0].formatted_content manifest = list(manifest_dict.values()) # type: ignore return manifest @@ -403,10 +405,11 @@ def _get_header_impl( versioned_model_id = self._model_id_semantic_version_manifest_key_cache.get( JumpStartVersionedModelId(model_id, semantic_version_str) - ) + )[0] + manifest = self._s3_cache.get( JumpStartCachedS3ContentKey(JumpStartS3FileType.MANIFEST, self._manifest_file_s3_key) - ).formatted_content + )[0].formatted_content try: header = manifest[versioned_model_id] # type: ignore return header @@ -427,10 +430,18 @@ def get_specs(self, model_id: str, semantic_version_str: str) -> JumpStartModelS header = self.get_header(model_id, semantic_version_str) spec_key = header.spec_key - specs = self._s3_cache.get( + specs, cache_hit = self._s3_cache.get( JumpStartCachedS3ContentKey(JumpStartS3FileType.SPECS, spec_key) - ).formatted_content - return specs # type: ignore + ) + if not cache_hit and "*" in semantic_version_str: + JUMPSTART_LOGGER.warning( + get_wildcard_model_version_msg( + header.model_id, + semantic_version_str, + header.version + ) + ) + return specs.formatted_content def clear(self) -> None: """Clears the model ID/version and s3 cache.""" diff --git a/src/sagemaker/jumpstart/constants.py b/src/sagemaker/jumpstart/constants.py index 79f416215b..e660cd65cc 100644 --- a/src/sagemaker/jumpstart/constants.py +++ b/src/sagemaker/jumpstart/constants.py @@ -168,6 +168,7 @@ SUPPORTED_JUMPSTART_SCOPES = set(scope.value for scope in JumpStartScriptScope) ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE = "AWS_JUMPSTART_CONTENT_BUCKET_OVERRIDE" +ENV_VARIABLE_JUMPSTART_GATED_CONTENT_BUCKET_OVERRIDE = "AWS_JUMPSTART_GATED_CONTENT_BUCKET_OVERRIDE" ENV_VARIABLE_JUMPSTART_MODEL_ARTIFACT_BUCKET_OVERRIDE = "AWS_JUMPSTART_MODEL_BUCKET_OVERRIDE" ENV_VARIABLE_JUMPSTART_SCRIPT_ARTIFACT_BUCKET_OVERRIDE = "AWS_JUMPSTART_SCRIPT_BUCKET_OVERRIDE" ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE = ( diff --git a/src/sagemaker/jumpstart/exceptions.py b/src/sagemaker/jumpstart/exceptions.py index b9bb530373..aef540b8b0 100644 --- a/src/sagemaker/jumpstart/exceptions.py +++ b/src/sagemaker/jumpstart/exceptions.py @@ -30,6 +30,35 @@ ) +_MAJOR_VERSION_WARNING_MSG = ( + "Note that models may have different input/output signatures after a major version upgrade." +) + + +def get_wildcard_model_version_msg( + model_id: str, wildcard_model_version: str, full_model_version: str +) -> str: + """Returns customer-facing message for using a model version with a wildcard character.""" + + return ( + f"Using model '{model_id}' with wildcard version identifier '{wildcard_model_version}'. " + f"You can pin to version '{full_model_version}' " + f"for more stable results. {_MAJOR_VERSION_WARNING_MSG}" + ) + + +def get_old_model_version_msg( + model_id: str, current_model_version: str, latest_model_version: str +) -> str: + """Returns customer-facing message associated with using an old model version.""" + + return ( + f"Using model '{model_id}' with version '{current_model_version}'. " + f"You can upgrade to version '{latest_model_version}' to get the latest model " + f"specifications. {_MAJOR_VERSION_WARNING_MSG}" + ) + + class JumpStartHyperparametersError(ValueError): """Exception raised for bad hyperparameters of a JumpStart model.""" diff --git a/src/sagemaker/jumpstart/factory/estimator.py b/src/sagemaker/jumpstart/factory/estimator.py index 1b24b714e7..c7396cdec5 100644 --- a/src/sagemaker/jumpstart/factory/estimator.py +++ b/src/sagemaker/jumpstart/factory/estimator.py @@ -16,6 +16,7 @@ from typing import Dict, List, Optional, Union from sagemaker import ( + environment_variables, hyperparameters as hyperparameters_utils, image_uris, instance_types, @@ -557,6 +558,18 @@ def _add_env_to_kwargs( ) -> JumpStartEstimatorInitKwargs: """Sets environment in kwargs based on default or override, returns full kwargs.""" + extra_env_vars = environment_variables.retrieve_default( + model_id=kwargs.model_id, + model_version=kwargs.model_version, + region=kwargs.region, + include_aws_sdk_env_vars=False, + tolerate_deprecated_model=kwargs.tolerate_deprecated_model, + tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, + sagemaker_session=kwargs.sagemaker_session, + script=JumpStartScriptScope.TRAINING, + instance_type=kwargs.instance_type, + ) + model_package_artifact_uri = _retrieve_model_package_model_artifact_s3_uri( model_id=kwargs.model_id, model_version=kwargs.model_version, @@ -568,12 +581,16 @@ def _add_env_to_kwargs( ) if model_package_artifact_uri: - if kwargs.environment is None: - kwargs.environment = {} - kwargs.environment = { - **{SAGEMAKER_GATED_MODEL_S3_URI_TRAINING_ENV_VAR_KEY: model_package_artifact_uri}, - **kwargs.environment, - } + extra_env_vars.update( + {SAGEMAKER_GATED_MODEL_S3_URI_TRAINING_ENV_VAR_KEY: model_package_artifact_uri} + ) + + for key, value in extra_env_vars.items(): + kwargs.environment = update_dict_if_key_not_present( + kwargs.environment, + key, + value, + ) return kwargs diff --git a/src/sagemaker/jumpstart/factory/model.py b/src/sagemaker/jumpstart/factory/model.py index 19605774ed..88934e91df 100644 --- a/src/sagemaker/jumpstart/factory/model.py +++ b/src/sagemaker/jumpstart/factory/model.py @@ -489,6 +489,7 @@ def get_deploy_kwargs( tolerate_vulnerable_model: Optional[bool] = None, tolerate_deprecated_model: Optional[bool] = None, sagemaker_session: Optional[Session] = None, + accept_eula: Optional[bool] = None, ) -> JumpStartModelDeployKwargs: """Returns kwargs required to call `deploy` on `sagemaker.estimator.Model` object.""" @@ -516,6 +517,7 @@ def get_deploy_kwargs( tolerate_deprecated_model=tolerate_deprecated_model, tolerate_vulnerable_model=tolerate_vulnerable_model, sagemaker_session=sagemaker_session, + accept_eula=accept_eula, ) deploy_kwargs = _add_sagemaker_session_to_kwargs(kwargs=deploy_kwargs) diff --git a/src/sagemaker/jumpstart/model.py b/src/sagemaker/jumpstart/model.py index c4bd62686c..204e4d3299 100644 --- a/src/sagemaker/jumpstart/model.py +++ b/src/sagemaker/jumpstart/model.py @@ -448,6 +448,7 @@ def deploy( container_startup_health_check_timeout: Optional[int] = None, inference_recommendation_id: Optional[str] = None, explainer_config: Optional[ExplainerConfig] = None, + accept_eula: Optional[bool] = None, ) -> PredictorBase: """Creates endpoint by calling base ``Model`` class `deploy` method. @@ -526,7 +527,11 @@ def deploy( (Default: None). explainer_config (Optional[sagemaker.explainer.ExplainerConfig]): Specifies online explainability configuration for use with Amazon SageMaker Clarify. (Default: None). - + accept_eula (bool): For models that require a Model Access Config, specify True or + False to indicate whether model terms of use have been accepted. + The `accept_eula` value must be explicitly defined as `True` in order to + accept the end-user license agreement (EULA) that some + models require. (Default: None). """ deploy_kwargs = get_deploy_kwargs( @@ -553,6 +558,7 @@ def deploy( inference_recommendation_id=inference_recommendation_id, explainer_config=explainer_config, sagemaker_session=self.sagemaker_session, + accept_eula=accept_eula, ) predictor = super(JumpStartModel, self).deploy(**deploy_kwargs.to_kwargs_dict()) diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index c4b51cc8b8..8a601eafdc 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -581,6 +581,16 @@ def get_instance_specific_environment_variables(self, instance_type: str) -> Dic return instance_family_environment_variables + def get_instance_specific_gated_model_key_env_var_value( + self, instance_type: str + ) -> Optional[str]: + """Returns instance specific gated model env var s3 key. + + Returns None if a model, instance type tuple does not have instance + specific property. + """ + return self._get_instance_specific_property(instance_type, "gated_model_key_env_var_value") + def get_instance_specific_default_inference_instance_type( self, instance_type: str ) -> Optional[str]: @@ -901,10 +911,12 @@ def use_inference_script_uri(self) -> bool: def use_training_model_artifact(self) -> bool: """Returns True if the model should use a model uri when kicking off training job.""" - return ( - self.training_model_package_artifact_uris is None - or len(self.training_model_package_artifact_uris) == 0 - ) + # gated model never use training model artifact + if self.gated_bucket: + return False + + # otherwise, return true is a training model package is not set + return len(self.training_model_package_artifact_uris or {}) == 0 def supports_incremental_training(self) -> bool: """Returns True if the model supports incremental training.""" @@ -1120,6 +1132,7 @@ class JumpStartModelDeployKwargs(JumpStartKwargs): "tolerate_deprecated_model", "sagemaker_session", "training_instance_type", + "accept_eula", ] SERIALIZATION_EXCLUSION_SET = { @@ -1158,6 +1171,7 @@ def __init__( tolerate_vulnerable_model: Optional[bool] = None, sagemaker_session: Optional[Session] = None, training_instance_type: Optional[str] = None, + accept_eula: Optional[bool] = None, ) -> None: """Instantiates JumpStartModelDeployKwargs object.""" @@ -1185,6 +1199,7 @@ def __init__( self.tolerate_deprecated_model = tolerate_deprecated_model self.sagemaker_session = sagemaker_session self.training_instance_type = training_instance_type + self.accept_eula = accept_eula class JumpStartEstimatorInitKwargs(JumpStartKwargs): diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py index 572230cf50..cd4ffcd702 100644 --- a/src/sagemaker/jumpstart/utils.py +++ b/src/sagemaker/jumpstart/utils.py @@ -16,6 +16,7 @@ import os from typing import Any, Dict, List, Optional, Union from urllib.parse import urlparse +import boto3 from packaging.version import Version import sagemaker from sagemaker.config.config_schema import ( @@ -31,6 +32,7 @@ from sagemaker.jumpstart.exceptions import ( DeprecatedJumpStartModelError, VulnerableJumpStartModelError, + get_old_model_version_msg, ) from sagemaker.jumpstart.types import ( JumpStartModelHeader, @@ -81,13 +83,13 @@ def get_jumpstart_gated_content_bucket( gated_bucket_to_return: Optional[str] = None if ( - constants.ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE in os.environ - and len(os.environ[constants.ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE]) > 0 + constants.ENV_VARIABLE_JUMPSTART_GATED_CONTENT_BUCKET_OVERRIDE in os.environ + and len(os.environ[constants.ENV_VARIABLE_JUMPSTART_GATED_CONTENT_BUCKET_OVERRIDE]) > 0 ): gated_bucket_to_return = os.environ[ - constants.ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE + constants.ENV_VARIABLE_JUMPSTART_GATED_CONTENT_BUCKET_OVERRIDE ] - info_logs.append(f"Using JumpStart private bucket override: '{gated_bucket_to_return}'") + info_logs.append(f"Using JumpStart gated bucket override: '{gated_bucket_to_return}'") else: try: gated_bucket_to_return = constants.JUMPSTART_REGION_NAME_TO_LAUNCHED_REGION_DICT[ @@ -462,7 +464,9 @@ def update_inference_tags_with_jumpstart_training_tags( return inference_tags -def emit_logs_based_on_model_specs(model_specs: JumpStartModelSpecs, region: str) -> None: +def emit_logs_based_on_model_specs( + model_specs: JumpStartModelSpecs, region: str, s3_client: boto3.client +) -> None: """Emits logs based on model specs and region.""" if model_specs.hosting_eula_key: @@ -476,6 +480,24 @@ def emit_logs_based_on_model_specs(model_specs: JumpStartModelSpecs, region: str model_specs.hosting_eula_key, ) + full_version: str = model_specs.version + + models_manifest_list = accessors.JumpStartModelsAccessor._get_manifest( + region=region, s3_client=s3_client + ) + max_version_for_model_id: Optional[str] = None + for header in models_manifest_list: + if header.model_id == model_specs.model_id: + if max_version_for_model_id is None or Version(header.version) > Version( + max_version_for_model_id + ): + max_version_for_model_id = header.version + + if full_version != max_version_for_model_id: + constants.JUMPSTART_LOGGER.info( + get_old_model_version_msg(model_specs.model_id, full_version, max_version_for_model_id) + ) + if model_specs.deprecated: deprecated_message = model_specs.deprecated_message or ( "Using deprecated JumpStart model " @@ -589,11 +611,18 @@ def verify_model_region_and_return_specs( def update_dict_if_key_not_present( - dict_to_update: dict, key_to_add: Any, value_to_add: Any -) -> dict: - """If a key is not present in the dict, add the new (key, value) pair, and return dict.""" + dict_to_update: Optional[dict], key_to_add: Any, value_to_add: Any +) -> Optional[dict]: + """If a key is not present in the dict, add the new (key, value) pair, and return dict. + + If dict is empty, return None. + """ + if dict_to_update is None: + dict_to_update = {} if key_to_add not in dict_to_update: dict_to_update[key_to_add] = value_to_add + if dict_to_update == {}: + dict_to_update = None return dict_to_update @@ -726,13 +755,5 @@ def is_valid_model_id( if script == enums.JumpStartScriptScope.INFERENCE: return model_id in model_id_set if script == enums.JumpStartScriptScope.TRAINING: - return ( - model_id in model_id_set - and accessors.JumpStartModelsAccessor.get_model_specs( - region=region, - model_id=model_id, - version=model_version, - s3_client=s3_client, - ).training_supported - ) + return model_id in model_id_set raise ValueError(f"Unsupported script: {script}") diff --git a/src/sagemaker/model.py b/src/sagemaker/model.py index ee9d562b2b..5e0afab942 100644 --- a/src/sagemaker/model.py +++ b/src/sagemaker/model.py @@ -379,6 +379,7 @@ def __init__( self.repacked_model_data = None self.content_types = None self.response_types = None + self.accept_eula = None @runnable_by_pipeline def register( @@ -634,6 +635,7 @@ def prepare_container_def( self.repacked_model_data or self.model_data, deploy_env, image_config=self.image_config, + accept_eula=getattr(self, "accept_eula", None), ) def is_repack(self) -> bool: @@ -1260,6 +1262,7 @@ def deploy( container_startup_health_check_timeout=None, inference_recommendation_id=None, explainer_config=None, + accept_eula: Optional[bool] = None, **kwargs, ): """Deploy this ``Model`` to an ``Endpoint`` and optionally return a ``Predictor``. @@ -1342,6 +1345,11 @@ def deploy( a list of ``RealtimeInferenceRecommendations`` within ``DeploymentRecommendation`` explainer_config (sagemaker.explainer.ExplainerConfig): Specifies online explainability configuration for use with Amazon SageMaker Clarify. Default: None. + accept_eula (bool): For models that require a Model Access Config, specify True or + False to indicate whether model terms of use have been accepted. + The `accept_eula` value must be explicitly defined as `True` in order to + accept the end-user license agreement (EULA) that some + models require. (Default: None). Raises: ValueError: If arguments combination check failed in these circumstances: - If no role is specified or @@ -1355,6 +1363,8 @@ def deploy( ``self.predictor_cls`` on the created endpoint name, if ``self.predictor_cls`` is not None. Otherwise, return None. """ + self.accept_eula = accept_eula + removed_kwargs("update_endpoint", kwargs) self._init_sagemaker_session_if_does_not_exist(instance_type) diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index ca1721942f..4389af3a36 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -6128,7 +6128,14 @@ def update_args(args: Dict[str, Any], **kwargs): args.update({key: value}) -def container_def(image_uri, model_data_url=None, env=None, container_mode=None, image_config=None): +def container_def( + image_uri, + model_data_url=None, + env=None, + container_mode=None, + image_config=None, + accept_eula=None, +): """Create a definition for executing a container as part of a SageMaker model. Args: @@ -6145,6 +6152,11 @@ def container_def(image_uri, model_data_url=None, env=None, container_mode=None, image_config (dict[str, str]): Specifies whether the image of model container is pulled from ECR, or private registry in your VPC. By default it is set to pull model container image from ECR. (default: None). + accept_eula (bool): For models that require a Model Access Config, specify True or + False to indicate whether model terms of use have been accepted. + The `accept_eula` value must be explicitly defined as `True` in order to + accept the end-user license agreement (EULA) that some + models require. (Default: None). Returns: dict[str, str]: A complete container definition object usable with the CreateModel API if @@ -6154,9 +6166,28 @@ def container_def(image_uri, model_data_url=None, env=None, container_mode=None, env = {} c_def = {"Image": image_uri, "Environment": env} - if isinstance(model_data_url, dict): - c_def["ModelDataSource"] = model_data_url - elif model_data_url: + if isinstance(model_data_url, str) and ( + not (model_data_url.startswith("s3://") and model_data_url.endswith("tar.gz")) + or accept_eula is None + ): + c_def["ModelDataUrl"] = model_data_url + + elif isinstance(model_data_url, (dict, str)): + if isinstance(model_data_url, dict): + c_def["ModelDataSource"] = model_data_url + else: + c_def["ModelDataSource"] = { + "S3DataSource": { + "S3Uri": model_data_url, + "S3DataType": "S3Object", + "CompressionType": "Gzip", + } + } + if accept_eula is not None: + c_def["ModelDataSource"]["S3DataSource"]["ModelAccessConfig"] = { + "AcceptEula": accept_eula + } + elif model_data_url is not None: c_def["ModelDataUrl"] = model_data_url if container_mode: diff --git a/src/sagemaker/utilities/cache.py b/src/sagemaker/utilities/cache.py index b5a48ccef8..d206f78963 100644 --- a/src/sagemaker/utilities/cache.py +++ b/src/sagemaker/utilities/cache.py @@ -15,7 +15,7 @@ import datetime import collections -from typing import TypeVar, Generic, Callable, Optional +from typing import Tuple, TypeVar, Generic, Callable, Optional KeyType = TypeVar("KeyType") ValType = TypeVar("ValType") @@ -86,23 +86,26 @@ def clear(self) -> None: """Deletes all elements from the cache.""" self._lru_cache.clear() - def get(self, key: KeyType, data_source_fallback: Optional[bool] = True) -> ValType: - """Returns value corresponding to key in cache. + def get( + self, key: KeyType, data_source_fallback: Optional[bool] = True + ) -> Tuple[ValType, bool]: + """Returns value corresponding to key in cache and boolean indicating cache hit. Args: key (KeyType): Key in cache to retrieve. data_source_fallback (Optional[bool]): True if data should be retrieved if it's stale or not in cache. Default: True. - Raises: - KeyError: If key is not found in cache or is outdated and - ``data_source_fallback`` is False. + + Raises: + KeyError: If key is not found in cache or is outdated and + ``data_source_fallback`` is False. """ if data_source_fallback: if key in self._lru_cache: - return self._get_item(key, False) + return self._get_item(key, False), True self.put(key) - return self._get_item(key, False) - return self._get_item(key, True) + return self._get_item(key, False), False + return self._get_item(key, True), True def put(self, key: KeyType, value: Optional[ValType] = None) -> None: """Adds key to cache using ``retrieval_function``. diff --git a/tests/data/dummy_code_bundle_with_reqs/main_script.py b/tests/data/dummy_code_bundle_with_reqs/main_script.py index c4b2951adf..e0fefe0f76 100644 --- a/tests/data/dummy_code_bundle_with_reqs/main_script.py +++ b/tests/data/dummy_code_bundle_with_reqs/main_script.py @@ -6,6 +6,6 @@ import local_module print("Trying to import module from requirements.txt...") -import stepfunctions +from aws_xray_sdk.core import xray_recorder print("Done") diff --git a/tests/data/dummy_code_bundle_with_reqs/requirements.txt b/tests/data/dummy_code_bundle_with_reqs/requirements.txt index 5f3c6b7ef7..88a2abd8ce 100644 --- a/tests/data/dummy_code_bundle_with_reqs/requirements.txt +++ b/tests/data/dummy_code_bundle_with_reqs/requirements.txt @@ -1,3 +1,3 @@ -# As a test dependency, we'll use the AWS Step Functions Data Science SDK - since it's more of a -# notebook environment tool than a training/processing job library, so shouldn't be in base images -stepfunctions +# We use aws-xray-sdk as a test dependency as it shouldn't be in base images. +# It is compatible with python 3.7 to 3.11, maintained regularly and does not depend on sagemaker +aws-xray-sdk diff --git a/tests/integ/sagemaker/jumpstart/estimator/test_jumpstart_estimator.py b/tests/integ/sagemaker/jumpstart/estimator/test_jumpstart_estimator.py index d5a7d7ff00..928013150e 100644 --- a/tests/integ/sagemaker/jumpstart/estimator/test_jumpstart_estimator.py +++ b/tests/integ/sagemaker/jumpstart/estimator/test_jumpstart_estimator.py @@ -13,6 +13,7 @@ from __future__ import absolute_import import os import time +import mock import pytest from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME @@ -127,7 +128,8 @@ def test_gated_model_training(setup): assert response is not None -def test_instatiating_estimator_not_too_slow(setup): +@mock.patch("sagemaker.jumpstart.cache.JUMPSTART_LOGGER.warning") +def test_instatiating_estimator(mock_warning_logger, setup): model_id = "xgboost-classification-model" @@ -142,3 +144,5 @@ def test_instatiating_estimator_not_too_slow(setup): elapsed_time = time.perf_counter() - start_time assert elapsed_time <= MAX_INIT_TIME_SECONDS + + mock_warning_logger.assert_called_once() diff --git a/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py b/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py index 9843b17c41..0dd48082b9 100644 --- a/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py +++ b/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py @@ -13,6 +13,7 @@ from __future__ import absolute_import import os import time +from unittest import mock import pytest @@ -114,7 +115,8 @@ def test_model_package_arn_jumpstart_model(setup): assert response is not None -def test_instatiating_model_not_too_slow(setup): +@mock.patch("sagemaker.jumpstart.cache.JUMPSTART_LOGGER.warning") +def test_instatiating_model(mock_warning_logger, setup): model_id = "catboost-regression-model" @@ -130,12 +132,15 @@ def test_instatiating_model_not_too_slow(setup): assert elapsed_time <= MAX_INIT_TIME_SECONDS + mock_warning_logger.assert_called_once() + def test_jumpstart_model_register(setup): model_id = "huggingface-txt2img-conflictx-complex-lineart" model = JumpStartModel( model_id=model_id, + model_version="1.1.0", role=get_sm_session().get_caller_identity_arn(), sagemaker_session=get_sm_session(), ) diff --git a/tests/integ/sagemaker/jumpstart/retrieve_uri/test_inference.py b/tests/integ/sagemaker/jumpstart/retrieve_uri/test_inference.py index 9066821949..550e2481cd 100644 --- a/tests/integ/sagemaker/jumpstart/retrieve_uri/test_inference.py +++ b/tests/integ/sagemaker/jumpstart/retrieve_uri/test_inference.py @@ -43,18 +43,27 @@ def test_jumpstart_inference_retrieve_functions(setup): model_id=model_id, model_version=model_version, instance_type=instance_type, + tolerate_vulnerable_model=True, ) script_uri = script_uris.retrieve( - model_id=model_id, model_version=model_version, script_scope="inference" + model_id=model_id, + model_version=model_version, + script_scope="inference", + tolerate_vulnerable_model=True, ) model_uri = model_uris.retrieve( - model_id=model_id, model_version=model_version, model_scope="inference" + model_id=model_id, + model_version=model_version, + model_scope="inference", + tolerate_vulnerable_model=True, ) environment_vars = environment_variables.retrieve_default( - model_id=model_id, model_version=model_version + model_id=model_id, + model_version=model_version, + tolerate_vulnerable_model=True, ) inference_job = InferenceJobLauncher( diff --git a/tests/integ/sagemaker/jumpstart/script_mode_class/test_inference.py b/tests/integ/sagemaker/jumpstart/script_mode_class/test_inference.py index 406b1f95d0..eb9b75be4f 100644 --- a/tests/integ/sagemaker/jumpstart/script_mode_class/test_inference.py +++ b/tests/integ/sagemaker/jumpstart/script_mode_class/test_inference.py @@ -35,7 +35,10 @@ def test_jumpstart_inference_model_class(setup): model_id, model_version = "catboost-classification-model", "1.2.7" instance_type = instance_types.retrieve_default( - model_id=model_id, model_version=model_version, scope="inference" + model_id=model_id, + model_version=model_version, + scope="inference", + tolerate_vulnerable_model=True, ) instance_count = 1 @@ -48,24 +51,33 @@ def test_jumpstart_inference_model_class(setup): model_id=model_id, model_version=model_version, instance_type=instance_type, + tolerate_vulnerable_model=True, ) script_uri = script_uris.retrieve( - model_id=model_id, model_version=model_version, script_scope="inference" + model_id=model_id, + model_version=model_version, + script_scope="inference", + tolerate_vulnerable_model=True, ) model_uri = model_uris.retrieve( - model_id=model_id, model_version=model_version, model_scope="inference" + model_id=model_id, + model_version=model_version, + model_scope="inference", + tolerate_vulnerable_model=True, ) env = environment_variables.retrieve_default( model_id=model_id, model_version=model_version, include_aws_sdk_env_vars=False, + tolerate_vulnerable_model=True, ) model_kwargs = _retrieve_model_init_kwargs( model_id=model_id, model_version=model_version, + tolerate_vulnerable_model=True, ) model = Model( @@ -83,6 +95,7 @@ def test_jumpstart_inference_model_class(setup): model_id=model_id, model_version=model_version, instance_type=instance_type, + tolerate_vulnerable_model=True, ) model.deploy( @@ -97,6 +110,7 @@ def test_jumpstart_inference_model_class(setup): model_id=model_id, model_version=model_version, sagemaker_session=get_sm_session(), + tolerate_vulnerable_model=True, ) download_inference_assets() diff --git a/tests/scripts/run-notebook-test.sh b/tests/scripts/run-notebook-test.sh index 8f33fbc97c..e2844d4728 100755 --- a/tests/scripts/run-notebook-test.sh +++ b/tests/scripts/run-notebook-test.sh @@ -136,7 +136,6 @@ echo "set SAGEMAKER_ROLE_ARN=$SAGEMAKER_ROLE_ARN" --platformIdentifier notebook-al2-v2 \ --consider-skips-failures \ ./amazon-sagemaker-examples/sagemaker_processing/spark_distributed_data_processing/sagemaker-spark-processing.ipynb \ -./amazon-sagemaker-examples/advanced_functionality/tensorflow_iris_byom/tensorflow_BYOM_iris.ipynb \ ./amazon-sagemaker-examples/sagemaker-python-sdk/1P_kmeans_highlevel/kmeans_mnist.ipynb \ ./amazon-sagemaker-examples/sagemaker-python-sdk/scikit_learn_randomforest/Sklearn_on_SageMaker_end2end.ipynb \ ./amazon-sagemaker-examples/sagemaker-pipelines/tabular/abalone_build_train_deploy/sagemaker-pipelines-preprocess-train-evaluate-batch-transform.ipynb \ @@ -144,4 +143,7 @@ echo "set SAGEMAKER_ROLE_ARN=$SAGEMAKER_ROLE_ARN" # Skipping test until fix in example notebook to install docker-compose is complete #./amazon-sagemaker-examples/sagemaker-python-sdk/tensorflow_moving_from_framework_mode_to_script_mode/tensorflow_moving_from_framework_mode_to_script_mode.ipynb \ +# Skipping this test until we fix the notebook to use the correct version of TensorFlow for training and inference +# ./amazon-sagemaker-examples/advanced_functionality/tensorflow_iris_byom/tensorflow_BYOM_iris.ipynb \ + (DeleteLifeCycleConfig "$LIFECYCLE_CONFIG_NAME") diff --git a/tests/unit/sagemaker/image_uris/test_djl.py b/tests/unit/sagemaker/image_uris/test_djl.py index 4665d6d8a2..6457fe044f 100644 --- a/tests/unit/sagemaker/image_uris/test_djl.py +++ b/tests/unit/sagemaker/image_uris/test_djl.py @@ -18,7 +18,12 @@ @pytest.mark.parametrize( "load_config_and_file_name", - ["djl-neuronx.json", "djl-fastertransformer.json", "djl-deepspeed.json"], + [ + "djl-neuronx.json", + "djl-fastertransformer.json", + "djl-deepspeed.json", + "djl-tensorrtllm.json", + ], indirect=True, ) def test_djl_uris(load_config_and_file_name): diff --git a/tests/unit/sagemaker/jumpstart/constants.py b/tests/unit/sagemaker/jumpstart/constants.py index 6551497318..bf4a4cd031 100644 --- a/tests/unit/sagemaker/jumpstart/constants.py +++ b/tests/unit/sagemaker/jumpstart/constants.py @@ -1023,6 +1023,231 @@ "training_enable_network_isolation": False, "resource_name_base": "dfsdfsds", }, + "gated_variant-model": { + "model_id": "pytorch-ic-mobilenet-v2", + "gated_bucket": True, + "url": "https://pytorch.org/hub/pytorch_vision_mobilenet_v2/", + "version": "1.0.0", + "min_sdk_version": "2.49.0", + "training_supported": True, + "incremental_training_supported": True, + "hosting_ecr_specs": { + "framework": "pytorch", + "framework_version": "1.5.0", + "py_version": "py3", + }, + "training_instance_type_variants": None, + "hosting_instance_type_variants": { + "regional_aliases": { + "us-west-2": { + "gpu_image_uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/" + "huggingface-pytorch-inference:1.13.1-transformers4.26.0-gpu-py39-cu117-ubuntu20.04", + "cpu_image_uri": "867930986793.dkr.us-west-2.amazonaws.com/cpu-blah", + } + }, + "variants": { + "p2": { + "regional_properties": { + "image_uri": "$gpu_image_uri", + }, + "properties": { + "prepacked_artifact_key": "some-instance-specific/model/prefix/" + }, + }, + "p3": { + "regional_properties": { + "image_uri": "$gpu_image_uri", + } + }, + "p4": { + "regional_properties": { + "image_uri": "$gpu_image_uri", + } + }, + "g4dn": { + "regional_properties": { + "image_uri": "$gpu_image_uri", + } + }, + "m2": {"regional_properties": {"image_uri": "$cpu_image_uri"}}, + "c2": {"regional_properties": {"image_uri": "$cpu_image_uri"}}, + "ml.g5.48xlarge": { + "properties": {"environment_variables": {"TENSOR_PARALLEL_DEGREE": "8"}} + }, + "ml.g5.12xlarge": { + "properties": {"environment_variables": {"TENSOR_PARALLEL_DEGREE": "4"}} + }, + }, + }, + "training_ecr_specs": { + "framework": "pytorch", + "framework_version": "1.5.0", + "py_version": "py3", + }, + "hosting_artifact_key": "pytorch-infer/infer-pytorch-ic-mobilenet-v2.tar.gz", + "training_artifact_key": "pytorch-training/train-pytorch-ic-mobilenet-v2.tar.gz", + "hosting_script_key": "source-directory-tarballs/pytorch/inference/ic/v1.0.0/sourcedir.tar.gz", + "training_script_key": "source-directory-tarballs/pytorch/transfer_learning/ic/v1.0.0/sourcedir.tar.gz", + "training_prepacked_script_key": None, + "hosting_prepacked_artifact_key": None, + "training_model_package_artifact_uris": None, + "deprecate_warn_message": None, + "deprecated_message": None, + "hosting_eula_key": None, + "hyperparameters": [ + { + "name": "epochs", + "type": "int", + "default": 3, + "min": 1, + "max": 1000, + "scope": "algorithm", + }, + { + "name": "adam-learning-rate", + "type": "float", + "default": 0.05, + "min": 1e-08, + "max": 1, + "scope": "algorithm", + }, + { + "name": "batch-size", + "type": "int", + "default": 4, + "min": 1, + "max": 1024, + "scope": "algorithm", + }, + { + "name": "sagemaker_submit_directory", + "type": "text", + "default": "/opt/ml/input/data/code/sourcedir.tar.gz", + "scope": "container", + }, + { + "name": "sagemaker_program", + "type": "text", + "default": "transfer_learning.py", + "scope": "container", + }, + { + "name": "sagemaker_container_log_level", + "type": "text", + "default": "20", + "scope": "container", + }, + ], + "inference_environment_variables": [ + { + "name": "SAGEMAKER_PROGRAM", + "type": "text", + "default": "inference.py", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "SAGEMAKER_SUBMIT_DIRECTORY", + "type": "text", + "default": "/opt/ml/model/code", + "scope": "container", + "required_for_model_class": False, + }, + { + "name": "SAGEMAKER_CONTAINER_LOG_LEVEL", + "type": "text", + "default": "20", + "scope": "container", + "required_for_model_class": False, + }, + { + "name": "SAGEMAKER_MODEL_SERVER_TIMEOUT", + "type": "text", + "default": "3600", + "scope": "container", + "required_for_model_class": False, + }, + { + "name": "ENDPOINT_SERVER_TIMEOUT", + "type": "int", + "default": 3600, + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "MODEL_CACHE_ROOT", + "type": "text", + "default": "/opt/ml/model", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "SAGEMAKER_ENV", + "type": "text", + "default": "1", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "SAGEMAKER_MODEL_SERVER_WORKERS", + "type": "int", + "default": 1, + "scope": "container", + "required_for_model_class": True, + }, + ], + "inference_vulnerable": False, + "inference_dependencies": [], + "inference_vulnerabilities": [], + "training_vulnerable": False, + "training_dependencies": [], + "training_vulnerabilities": [], + "deprecated": False, + "default_inference_instance_type": "ml.p2.xlarge", + "supported_inference_instance_types": [ + "ml.p2.xlarge", + "ml.p3.2xlarge", + "ml.g4dn.xlarge", + "ml.m5.large", + "ml.m5.xlarge", + "ml.c5.xlarge", + "ml.c5.2xlarge", + ], + "default_training_instance_type": "ml.p3.2xlarge", + "supported_training_instance_types": [ + "ml.p3.2xlarge", + "ml.p2.xlarge", + "ml.g4dn.2xlarge", + "ml.m5.xlarge", + "ml.c5.2xlarge", + ], + "hosting_use_script_uri": False, + "metrics": [ + { + "Name": "huggingface-textgeneration:train-loss", + "Regex": "'loss default': ([0-9]+\\.[0-9]+)", + }, + { + "Name": "huggingface-textgeyyyuyuyuyneration:train-loss", + "Regex": "'loss default': ([0-9]+\\.[0-9]+)", + }, + ], + "estimator_kwargs": { + "encrypt_inter_container_traffic": True, + }, + "fit_kwargs": {"some-estimator-fit-key": "some-estimator-fit-value"}, + "predictor_specs": { + "supported_content_types": ["application/x-image"], + "supported_accept_types": ["application/json;verbose", "application/json"], + "default_content_type": "application/x-image", + "default_accept_type": "application/json", + }, + "inference_volume_size": 123, + "training_volume_size": 456, + "inference_enable_network_isolation": True, + "training_enable_network_isolation": False, + "resource_name_base": "dfsdfsds", + }, "model-artifact-variant-model": { "model_id": "pytorch-ic-mobilenet-v2", "url": "https://pytorch.org/hub/pytorch_vision_mobilenet_v2/", @@ -2103,6 +2328,470 @@ "default_accept_type": "application/json", }, }, + "js-gated-artifact-non-model-package-trainable-model": { + "model_id": "meta-textgeneration-llama-2-7b", + "url": "https://ai.meta.com/resources/models-and-libraries/llama-downloads/", + "version": "3.0.0", + "min_sdk_version": "2.189.0", + "training_supported": True, + "incremental_training_supported": False, + "hosting_ecr_specs": { + "framework": "huggingface-llm", + "framework_version": "1.1.0", + "py_version": "py39", + }, + "training_artifact_key": "some/dummy/key", + "hosting_artifact_key": "meta-textgeneration/meta-textgeneration-llama-2-7b/artifacts/inference/v1.0.0/", + "hosting_script_key": "source-directory-tarballs/meta/inference/textgeneration/v1.2.3/sourcedir.tar.gz", + "hosting_prepacked_artifact_key": "meta-textgeneration/meta-textgen" + "eration-llama-2-7b/artifacts/inference-prepack/v1.0.0/", + "hosting_prepacked_artifact_version": "1.0.0", + "hosting_use_script_uri": False, + "hosting_eula_key": "fmhMetadata/eula/llamaEula.txt", + "inference_vulnerable": False, + "inference_dependencies": [ + "sagemaker_jumpstart_huggingface_script_utilities==1.0.8", + "sagemaker_jumpstart_script_utilities==1.1.8", + ], + "inference_vulnerabilities": [], + "training_vulnerable": False, + "training_dependencies": [ + "accelerate==0.21.0", + "bitsandbytes==0.39.1", + "black==23.7.0", + "brotli==1.0.9", + "datasets==2.14.1", + "fire==0.5.0", + "inflate64==0.3.1", + "loralib==0.1.1", + "multivolumefile==0.2.3", + "mypy-extensions==1.0.0", + "pathspec==0.11.1", + "peft==0.4.0", + "py7zr==0.20.5", + "pybcj==1.0.1", + "pycryptodomex==3.18.0", + "pyppmd==1.0.0", + "pytorch-triton==2.1.0+e6216047b8", + "pyzstd==0.15.9", + "safetensors==0.3.1", + "sagemaker_jumpstart_huggingface_script_utilities==1.1.3", + "sagemaker_jumpstart_script_utilities==1.1.9", + "scipy==1.11.1", + "termcolor==2.3.0", + "texttable==1.6.7", + "tokenize-rt==5.1.0", + "tokenizers==0.13.3", + "torch==2.1.0.dev20230905+cu118", + "transformers==4.31.0", + ], + "training_vulnerabilities": [], + "deprecated": False, + "hyperparameters": [ + { + "name": "int8_quantization", + "type": "text", + "default": "False", + "options": ["True", "False"], + "scope": "algorithm", + }, + { + "name": "enable_fsdp", + "type": "text", + "default": "True", + "options": ["True", "False"], + "scope": "algorithm", + }, + { + "name": "epoch", + "type": "int", + "default": 5, + "min": 1, + "max": 1000, + "scope": "algorithm", + }, + { + "name": "learning_rate", + "type": "float", + "default": 0.0001, + "min": 1e-08, + "max": 1, + "scope": "algorithm", + }, + {"name": "lora_r", "type": "int", "default": 8, "min": 1, "scope": "algorithm"}, + {"name": "lora_alpha", "type": "int", "default": 32, "min": 1, "scope": "algorithm"}, + { + "name": "lora_dropout", + "type": "float", + "default": 0.05, + "min": 0, + "max": 1, + "scope": "algorithm", + }, + { + "name": "instruction_tuned", + "type": "text", + "default": "False", + "options": ["True", "False"], + "scope": "algorithm", + }, + { + "name": "chat_dataset", + "type": "text", + "default": "False", + "options": ["True", "False"], + "scope": "algorithm", + }, + { + "name": "add_input_output_demarcation_key", + "type": "text", + "default": "True", + "options": ["True", "False"], + "scope": "algorithm", + }, + { + "name": "per_device_train_batch_size", + "type": "int", + "default": 4, + "min": 1, + "max": 1000, + "scope": "algorithm", + }, + { + "name": "per_device_eval_batch_size", + "type": "int", + "default": 1, + "min": 1, + "max": 1000, + "scope": "algorithm", + }, + { + "name": "max_train_samples", + "type": "int", + "default": -1, + "min": -1, + "scope": "algorithm", + }, + { + "name": "max_val_samples", + "type": "int", + "default": -1, + "min": -1, + "scope": "algorithm", + }, + { + "name": "seed", + "type": "int", + "default": 10, + "min": 1, + "max": 1000, + "scope": "algorithm", + }, + { + "name": "max_input_length", + "type": "int", + "default": -1, + "min": -1, + "scope": "algorithm", + }, + { + "name": "validation_split_ratio", + "type": "float", + "default": 0.2, + "min": 0, + "max": 1, + "scope": "algorithm", + }, + { + "name": "train_data_split_seed", + "type": "int", + "default": 0, + "min": 0, + "scope": "algorithm", + }, + { + "name": "preprocessing_num_workers", + "type": "text", + "default": "None", + "scope": "algorithm", + }, + { + "name": "sagemaker_submit_directory", + "type": "text", + "default": "/opt/ml/input/data/code/sourcedir.tar.gz", + "scope": "container", + }, + { + "name": "sagemaker_program", + "type": "text", + "default": "transfer_learning.py", + "scope": "container", + }, + { + "name": "sagemaker_container_log_level", + "type": "text", + "default": "20", + "scope": "container", + }, + ], + "training_script_key": "source-directory-tarballs/" + "meta/transfer_learning/textgeneration/v1.0.4/sourcedir.tar.gz", + "training_prepacked_script_key": "source-directory-" + "tarballs/meta/transfer_learning/textgeneration/prepack/v1.0.1/sourcedir.tar.gz", + "training_prepacked_script_version": "1.0.1", + "training_ecr_specs": { + "framework": "huggingface", + "framework_version": "2.0.0", + "py_version": "py310", + "huggingface_transformers_version": "4.28.1", + }, + "inference_environment_variables": [ + { + "name": "SAGEMAKER_PROGRAM", + "type": "text", + "default": "inference.py", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "SAGEMAKER_SUBMIT_DIRECTORY", + "type": "text", + "default": "/opt/ml/model/code", + "scope": "container", + "required_for_model_class": False, + }, + { + "name": "SAGEMAKER_CONTAINER_LOG_LEVEL", + "type": "text", + "default": "20", + "scope": "container", + "required_for_model_class": False, + }, + { + "name": "SAGEMAKER_MODEL_SERVER_TIMEOUT", + "type": "text", + "default": "3600", + "scope": "container", + "required_for_model_class": False, + }, + { + "name": "ENDPOINT_SERVER_TIMEOUT", + "type": "int", + "default": 3600, + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "MODEL_CACHE_ROOT", + "type": "text", + "default": "/opt/ml/model", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "SAGEMAKER_ENV", + "type": "text", + "default": "1", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "HF_MODEL_ID", + "type": "text", + "default": "/opt/ml/model", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "MAX_INPUT_LENGTH", + "type": "text", + "default": "4095", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "MAX_TOTAL_TOKENS", + "type": "text", + "default": "4096", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "SM_NUM_GPUS", + "type": "text", + "default": "1", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "SAGEMAKER_MODEL_SERVER_WORKERS", + "type": "int", + "default": 1, + "scope": "container", + "required_for_model_class": True, + }, + ], + "metrics": [ + { + "Name": "huggingface-textgeneration:eval-loss", + "Regex": "eval_epoch_loss=tensor\\(([0-9\\.]+)", + }, + { + "Name": "huggingface-textgeneration:eval-ppl", + "Regex": "eval_ppl=tensor\\(([0-9\\.]+)", + }, + { + "Name": "huggingface-textgeneration:train-loss", + "Regex": "train_epoch_loss=([0-9\\.]+)", + }, + ], + "default_inference_instance_type": "ml.g5.2xlarge", + "supported_inference_instance_types": [ + "ml.g5.2xlarge", + "ml.g5.4xlarge", + "ml.g5.8xlarge", + "ml.g5.12xlarge", + "ml.g5.24xlarge", + "ml.g5.48xlarge", + "ml.p4d.24xlarge", + ], + "default_training_instance_type": "ml.g5.12xlarge", + "supported_training_instance_types": [ + "ml.g5.12xlarge", + "ml.g5.24xlarge", + "ml.g5.48xlarge", + "ml.p3dn.24xlarge", + ], + "model_kwargs": {}, + "deploy_kwargs": { + "model_data_download_timeout": 1200, + "container_startup_health_check_timeout": 1200, + }, + "estimator_kwargs": {"encrypt_inter_container_traffic": True, "max_run": 360000}, + "fit_kwargs": {}, + "predictor_specs": { + "supported_content_types": ["application/json"], + "supported_accept_types": ["application/json"], + "default_content_type": "application/json", + "default_accept_type": "application/json", + }, + "inference_volume_size": 256, + "training_volume_size": 256, + "inference_enable_network_isolation": True, + "training_enable_network_isolation": True, + "default_training_dataset_key": "training-datasets/sec_amazon/", + "validation_supported": True, + "fine_tuning_supported": True, + "resource_name_base": "meta-textgeneration-llama-2-7b", + "default_payloads": { + "meaningOfLife": { + "content_type": "application/json", + "prompt_key": "inputs", + "output_keys": {"generated_text": "[0].generated_text"}, + "body": { + "inputs": "I believe the meaning of life is", + "parameters": {"max_new_tokens": 64, "top_p": 0.9, "temperature": 0.6}, + }, + }, + "theoryOfRelativity": { + "content_type": "application/json", + "prompt_key": "inputs", + "output_keys": {"generated_text": "[0].generated_text"}, + "body": { + "inputs": "Simply put, the theory of relativity states that ", + "parameters": {"max_new_tokens": 64, "top_p": 0.9, "temperature": 0.6}, + }, + }, + "teamMessage": { + "content_type": "application/json", + "prompt_key": "inputs", + "output_keys": {"generated_text": "[0].generated_text"}, + "body": { + "inputs": "A brief message congratulating the team on the launch:\n\nHi everyone,\n\nI just ", + "parameters": {"max_new_tokens": 64, "top_p": 0.9, "temperature": 0.6}, + }, + }, + "englishToFrench": { + "content_type": "application/json", + "prompt_key": "inputs", + "output_keys": {"generated_text": "[0].generated_text"}, + "body": { + "inputs": "Translate English to French:\nsea o" + "tter => loutre de mer\npeppermint => ment" + "he poivr\u00e9e\nplush girafe => girafe peluche\ncheese =>", + "parameters": {"max_new_tokens": 64, "top_p": 0.9, "temperature": 0.6}, + }, + }, + "Story": { + "content_type": "application/json", + "prompt_key": "inputs", + "output_keys": { + "generated_text": "[0].generated_text", + "input_logprobs": "[0].details.prefill[*].logprob", + }, + "body": { + "inputs": "Please tell me a story.", + "parameters": { + "max_new_tokens": 64, + "top_p": 0.9, + "temperature": 0.2, + "decoder_input_details": True, + "details": True, + }, + }, + }, + }, + "gated_bucket": True, + "hosting_instance_type_variants": { + "regional_aliases": { + "us-west-2": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.us-west-2.amazonaws.com/h" + "uggingface-pytorch-tgi-inference:2.0.1-tgi1.1.0-gpu-py39-cu118-ubuntu20.04" + }, + }, + "variants": { + "g4dn": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "g5": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "local_gpu": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p2": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p3": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p3dn": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p4d": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p4de": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p5": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "ml.g5.12xlarge": {"properties": {"environment_variables": {"SM_NUM_GPUS": "4"}}}, + "ml.g5.24xlarge": {"properties": {"environment_variables": {"SM_NUM_GPUS": "4"}}}, + "ml.g5.48xlarge": {"properties": {"environment_variables": {"SM_NUM_GPUS": "8"}}}, + "ml.p4d.24xlarge": {"properties": {"environment_variables": {"SM_NUM_GPUS": "8"}}}, + }, + }, + "training_instance_type_variants": { + "regional_aliases": { + "us-west-2": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.us-west-2.amazon" + "aws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + }, + "variants": { + "g4dn": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "g5": { + "regional_properties": {"image_uri": "$gpu_ecr_uri_1"}, + "properties": { + "gated_model_key_env_var_value": "meta-training/train-meta-textgeneration-llama-2-7b.tar.gz", + "environment_variables": {"SELF_DESTRUCT": "true"}, + }, + }, + "local_gpu": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p2": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p3": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p3dn": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p4d": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p4de": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p5": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + }, + }, + "dynamic_container_deployment_supported": False, + }, "js-gated-artifact-trainable-model": { "model_id": "meta-textgeneration-llama-2-7b-f", "url": "https://ai.meta.com/resources/models-and-libraries/llama-downloads/", diff --git a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py index 6f4788fa04..1010cd7cf9 100644 --- a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py +++ b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py @@ -23,6 +23,13 @@ from sagemaker.debugger.profiler_config import ProfilerConfig from sagemaker.estimator import Estimator from sagemaker.instance_group import InstanceGroup +from sagemaker.jumpstart.artifacts.environment_variables import ( + _retrieve_default_environment_variables, +) +from sagemaker.jumpstart.artifacts.hyperparameters import _retrieve_default_hyperparameters +from sagemaker.jumpstart.artifacts.metric_definitions import ( + _retrieve_default_training_metric_definitions, +) from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION from sagemaker.jumpstart.enums import JumpStartScriptScope, JumpStartTag @@ -391,6 +398,157 @@ def test_gated_model_s3_uri( ], ) + @mock.patch( + "sagemaker.jumpstart.artifacts.environment_variables.get_jumpstart_gated_content_bucket" + ) + @mock.patch("sagemaker.utils.sagemaker_timestamp") + @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.factory.model.Session") + @mock.patch("sagemaker.jumpstart.factory.estimator.Session") + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") + @mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__") + @mock.patch("sagemaker.jumpstart.estimator.Estimator.fit") + @mock.patch("sagemaker.jumpstart.estimator.Estimator.deploy") + @mock.patch("sagemaker.jumpstart.factory.estimator.JUMPSTART_DEFAULT_REGION_NAME", region) + @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) + def test_gated_model_non_model_package_s3_uri( + self, + mock_estimator_deploy: mock.Mock, + mock_estimator_fit: mock.Mock, + mock_estimator_init: mock.Mock, + mock_get_model_specs: mock.Mock, + mock_session_estimator: mock.Mock, + mock_session_model: mock.Mock, + mock_is_valid_model_id: mock.Mock, + mock_timestamp: mock.Mock, + mock_get_jumpstart_gated_content_bucket: mock.Mock, + ): + mock_estimator_deploy.return_value = default_predictor + + mock_get_jumpstart_gated_content_bucket.return_value = "top-secret-private-models-bucket" + mock_timestamp.return_value = "8675309" + + mock_is_valid_model_id.return_value = True + + model_id, _ = "js-gated-artifact-non-model-package-trainable-model", "*" + + mock_get_model_specs.side_effect = get_special_model_spec + + mock_session_estimator.return_value = sagemaker_session + mock_session_model.return_value = sagemaker_session + + estimator = JumpStartEstimator(model_id=model_id, environment={"accept_eula": True}) + + mock_estimator_init.assert_called_once_with( + instance_type="ml.g5.12xlarge", + instance_count=1, + image_uri="763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pyt" + "orch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04", + source_dir="s3://jumpstart-cache-prod-us-west-2/source-d" + "irectory-tarballs/meta/transfer_learning/textgeneration/prepack/v1.0.1/sourcedir.tar.gz", + entry_point="transfer_learning.py", + hyperparameters={ + "int8_quantization": "False", + "enable_fsdp": "True", + "epoch": "5", + "learning_rate": "0.0001", + "lora_r": "8", + "lora_alpha": "32", + "lora_dropout": "0.05", + "instruction_tuned": "False", + "chat_dataset": "False", + "add_input_output_demarcation_key": "True", + "per_device_train_batch_size": "4", + "per_device_eval_batch_size": "1", + "max_train_samples": "-1", + "max_val_samples": "-1", + "seed": "10", + "max_input_length": "-1", + "validation_split_ratio": "0.2", + "train_data_split_seed": "0", + "preprocessing_num_workers": "None", + }, + metric_definitions=[ + { + "Name": "huggingface-textgeneration:eval-loss", + "Regex": "eval_epoch_loss=tensor\\(([0-9\\.]+)", + }, + { + "Name": "huggingface-textgeneration:eval-ppl", + "Regex": "eval_ppl=tensor\\(([0-9\\.]+)", + }, + { + "Name": "huggingface-textgeneration:train-loss", + "Regex": "train_epoch_loss=([0-9\\.]+)", + }, + ], + role="fake role! do not use!", + max_run=360000, + sagemaker_session=sagemaker_session, + tags=[ + { + "Key": "sagemaker-sdk:jumpstart-model-id", + "Value": "js-gated-artifact-non-model-package-trainable-model", + }, + {"Key": "sagemaker-sdk:jumpstart-model-version", "Value": "3.0.0"}, + ], + encrypt_inter_container_traffic=True, + enable_network_isolation=True, + environment={ + "SELF_DESTRUCT": "true", + "accept_eula": True, + "SageMakerGatedModelS3Uri": "s3://top-secret-private-" + "models-bucket/meta-training/train-meta-textgeneration-llama-2-7b.tar.gz", + }, + ) + + channels = { + "training": f"s3://{get_jumpstart_content_bucket(region)}/" + f"some-training-dataset-doesn't-matter", + } + + estimator.fit(channels) + + mock_estimator_fit.assert_called_once_with( + inputs=channels, wait=True, job_name="meta-textgeneration-llama-2-7b-8675309" + ) + + estimator.deploy() + + mock_estimator_deploy.assert_called_once_with( + instance_type="ml.g5.2xlarge", + initial_instance_count=1, + image_uri="763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytor" + "ch-tgi-inference:2.0.1-tgi1.1.0-gpu-py39-cu118-ubuntu20.04", + env={ + "SAGEMAKER_PROGRAM": "inference.py", + "ENDPOINT_SERVER_TIMEOUT": "3600", + "MODEL_CACHE_ROOT": "/opt/ml/model", + "SAGEMAKER_ENV": "1", + "HF_MODEL_ID": "/opt/ml/model", + "MAX_INPUT_LENGTH": "4095", + "MAX_TOTAL_TOKENS": "4096", + "SM_NUM_GPUS": "1", + "SAGEMAKER_MODEL_SERVER_WORKERS": "1", + }, + predictor_cls=Predictor, + endpoint_name="meta-textgeneration-llama-2-7b-8675309", + tags=[ + { + "Key": "sagemaker-sdk:jumpstart-model-id", + "Value": "js-gated-artifact-non-model-package-trainable-model", + }, + {"Key": "sagemaker-sdk:jumpstart-model-version", "Value": "3.0.0"}, + ], + wait=True, + model_data_download_timeout=1200, + container_startup_health_check_timeout=1200, + role="fake role! do not use!", + enable_network_isolation=True, + model_name="meta-textgeneration-llama-2-7b-8675309", + use_compiled_model=False, + ) + @mock.patch("sagemaker.utils.sagemaker_timestamp") @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") @mock.patch("sagemaker.jumpstart.factory.model.Session") @@ -580,6 +738,9 @@ def test_estimator_use_kwargs(self): deploy_kwargs=all_deploy_kwargs_used, ) + @mock.patch("sagemaker.jumpstart.factory.estimator.hyperparameters_utils.retrieve_default") + @mock.patch("sagemaker.jumpstart.factory.estimator.metric_definitions_utils.retrieve_default") + @mock.patch("sagemaker.jumpstart.factory.estimator.environment_variables.retrieve_default") @mock.patch("sagemaker.utils.sagemaker_timestamp") @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") @mock.patch("sagemaker.jumpstart.factory.model.Session") @@ -600,11 +761,22 @@ def evaluate_estimator_workflow_with_kwargs( mock_session_model: mock.Mock, mock_is_valid_model_id: mock.Mock, mock_timestamp: mock.Mock, + mock_retrieve_default_environment_variables: mock.Mock, + mock_retrieve_metric_definitions: mock.Mock, + mock_retrieve_hyperparameters: mock.Mock, init_kwargs: Optional[dict] = None, fit_kwargs: Optional[dict] = None, deploy_kwargs: Optional[dict] = None, ): + mock_retrieve_default_environment_variables.side_effect = ( + _retrieve_default_environment_variables + ) + + mock_retrieve_metric_definitions.side_effect = _retrieve_default_training_metric_definitions + + mock_retrieve_hyperparameters.side_effect = _retrieve_default_hyperparameters + if init_kwargs is None: init_kwargs = {} @@ -684,6 +856,9 @@ def evaluate_estimator_workflow_with_kwargs( mock_estimator_fit.assert_called_once_with(**expected_fit_kwargs) + mock_retrieve_default_environment_variables.assert_called_once() + mock_retrieve_metric_definitions.assert_called_once() + mock_retrieve_hyperparameters.assert_called_once() estimator.deploy(**deploy_kwargs) expected_deploy_kwargs = overwrite_dictionary( diff --git a/tests/unit/sagemaker/jumpstart/model/test_model.py b/tests/unit/sagemaker/jumpstart/model/test_model.py index 3b6cbff2ad..f861ae22db 100644 --- a/tests/unit/sagemaker/jumpstart/model/test_model.py +++ b/tests/unit/sagemaker/jumpstart/model/test_model.py @@ -18,6 +18,9 @@ from mock import MagicMock import pytest from sagemaker.async_inference.async_inference_config import AsyncInferenceConfig +from sagemaker.jumpstart.artifacts.environment_variables import ( + _retrieve_default_environment_variables, +) from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION from sagemaker.jumpstart.enums import JumpStartScriptScope, JumpStartTag @@ -167,6 +170,69 @@ def test_prepacked( ], ) + @mock.patch("sagemaker.utils.sagemaker_timestamp") + @mock.patch("sagemaker.session.Session.endpoint_from_production_variants") + @mock.patch("sagemaker.session.Session.create_model") + @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.factory.model.Session") + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") + @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) + def test_eula_gated_conditional_s3_prefix_metadata_model( + self, + mock_get_model_specs: mock.Mock, + mock_session: mock.Mock, + mock_is_valid_model_id: mock.Mock, + mock_create_model: mock.Mock, + mock_endpoint_from_production_variants: mock.Mock, + mock_timestamp: mock.Mock, + ): + + mock_timestamp.return_value = "1234" + + mock_is_valid_model_id.return_value = True + + model_id, _ = "gated_variant-model", "*" + + mock_get_model_specs.side_effect = get_special_model_spec + + mock_session.return_value = sagemaker_session + + model = JumpStartModel( + model_id=model_id, + ) + + model.deploy(accept_eula=True, instance_type="ml.p2.xlarge") + + mock_create_model.assert_called_once_with( + name="dfsdfsds-1234", + role="fake role! do not use!", + container_defs={ + "Image": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-" + "inference:1.13.1-transformers4.26.0-gpu-py39-cu117-ubuntu20.04", + "Environment": { + "SAGEMAKER_PROGRAM": "inference.py", + "ENDPOINT_SERVER_TIMEOUT": "3600", + "MODEL_CACHE_ROOT": "/opt/ml/model", + "SAGEMAKER_ENV": "1", + "SAGEMAKER_MODEL_SERVER_WORKERS": "1", + }, + "ModelDataSource": { + "S3DataSource": { + "S3Uri": "s3://jumpstart-private-cache-prod-us-west-2/some-instance-specific/model/prefix/", + "S3DataType": "S3Prefix", + "CompressionType": "None", + "ModelAccessConfig": {"AcceptEula": True}, + } + }, + }, + vpc_config=None, + enable_network_isolation=True, + tags=[ + {"Key": "sagemaker-sdk:jumpstart-model-id", "Value": "gated_variant-model"}, + {"Key": "sagemaker-sdk:jumpstart-model-version", "Value": "1.0.0"}, + ], + ) + @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @@ -267,6 +333,7 @@ def test_model_use_kwargs(self): deploy_kwargs=all_deploy_kwargs_used, ) + @mock.patch("sagemaker.jumpstart.factory.model.environment_variables.retrieve_default") @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @@ -280,10 +347,13 @@ def evaluate_model_workflow_with_kwargs( mock_get_model_specs: mock.Mock, mock_session: mock.Mock, mock_is_valid_model_id: mock.Mock, + mock_retrieve_environment_variables: mock.Mock, init_kwargs: Optional[dict] = None, deploy_kwargs: Optional[dict] = None, ): + mock_retrieve_environment_variables.side_effect = _retrieve_default_environment_variables + mock_model_deploy.return_value = default_predictor mock_is_valid_model_id.return_value = True @@ -330,6 +400,8 @@ def evaluate_model_workflow_with_kwargs( model.deploy(**deploy_kwargs) + mock_retrieve_environment_variables.assert_called_once() + expected_deploy_kwargs = overwrite_dictionary( { "initial_instance_count": 1, diff --git a/tests/unit/sagemaker/jumpstart/test_cache.py b/tests/unit/sagemaker/jumpstart/test_cache.py index 58a8e34d25..6633ecdc23 100644 --- a/tests/unit/sagemaker/jumpstart/test_cache.py +++ b/tests/unit/sagemaker/jumpstart/test_cache.py @@ -42,6 +42,7 @@ BASE_MANIFEST, BASE_SPEC, ) +from sagemaker.jumpstart.utils import get_jumpstart_content_bucket @patch.object(JumpStartModelsCache, "_retrieval_function", patched_retrieval_function) @@ -525,11 +526,14 @@ def test_jumpstart_cache_evaluates_md5_hash(mock_boto3_client): ) +@patch("sagemaker.jumpstart.cache.utils.emit_logs_based_on_model_specs") @patch("boto3.client") -def test_jumpstart_cache_makes_correct_s3_calls(mock_boto3_client): +def test_jumpstart_cache_makes_correct_s3_calls( + mock_boto3_client, mock_emit_logs_based_on_model_specs +): # test get_header - mock_json = json.dumps( + mock_manifest_json = json.dumps( [ { "model_id": "pytorch-ic-imagenet-inception-v3-classification-4", @@ -542,17 +546,17 @@ def test_jumpstart_cache_makes_correct_s3_calls(mock_boto3_client): ) mock_boto3_client.return_value.get_object.return_value = { "Body": botocore.response.StreamingBody( - io.BytesIO(bytes(mock_json, "utf-8")), content_length=len(mock_json) + io.BytesIO(bytes(mock_manifest_json, "utf-8")), content_length=len(mock_manifest_json) ), "ETag": "etag", } mock_boto3_client.return_value.head_object.return_value = {"ETag": "some-hash"} - bucket_name = "bucket_name" + bucket_name = get_jumpstart_content_bucket("us-west-2") client_config = botocore.config.Config(signature_version="my_signature_version") cache = JumpStartModelsCache( - s3_bucket_name=bucket_name, s3_client_config=client_config, region="my_region" + s3_bucket_name=bucket_name, s3_client_config=client_config, region="us-west-2" ) cache.get_header( model_id="pytorch-ic-imagenet-inception-v3-classification-4", semantic_version_str="*" @@ -563,7 +567,7 @@ def test_jumpstart_cache_makes_correct_s3_calls(mock_boto3_client): ) mock_boto3_client.return_value.head_object.assert_not_called() - mock_boto3_client.assert_called_with("s3", region_name="my_region", config=client_config) + mock_boto3_client.assert_called_with("s3", region_name="us-west-2", config=client_config) # test get_specs. manifest already in cache, so only s3 call will be to get specs. mock_json = json.dumps(BASE_SPEC) @@ -576,9 +580,22 @@ def test_jumpstart_cache_makes_correct_s3_calls(mock_boto3_client): ), "ETag": "etag", } - cache.get_specs( - model_id="pytorch-ic-imagenet-inception-v3-classification-4", semantic_version_str="*" - ) + + with patch("logging.Logger.warning") as mocked_warning_log: + cache.get_specs( + model_id="pytorch-ic-imagenet-inception-v3-classification-4", semantic_version_str="*" + ) + mocked_warning_log.assert_called_once_with( + "Using model 'pytorch-ic-imagenet-inception-v3-classification-4' with wildcard " + "version identifier '*'. You can pin to version '2.0.0' for more " + "stable results. Note that models may have different input/output " + "signatures after a major version upgrade." + ) + mocked_warning_log.reset_mock() + cache.get_specs( + model_id="pytorch-ic-imagenet-inception-v3-classification-4", semantic_version_str="*" + ) + mocked_warning_log.assert_not_called() mock_boto3_client.return_value.get_object.assert_called_with( Bucket=bucket_name, @@ -595,10 +612,18 @@ def test_jumpstart_cache_handles_bad_semantic_version_manifest_key_cache(): cache.clear = MagicMock() cache._model_id_semantic_version_manifest_key_cache = MagicMock() cache._model_id_semantic_version_manifest_key_cache.get.side_effect = [ - JumpStartVersionedModelId( - "tensorflow-ic-imagenet-inception-v3-classification-4", "999.0.0" + ( + JumpStartVersionedModelId( + "tensorflow-ic-imagenet-inception-v3-classification-4", "999.0.0" + ), + True, + ), + ( + JumpStartVersionedModelId( + "tensorflow-ic-imagenet-inception-v3-classification-4", "1.0.0" + ), + True, ), - JumpStartVersionedModelId("tensorflow-ic-imagenet-inception-v3-classification-4", "1.0.0"), ] assert JumpStartModelHeader( @@ -616,11 +641,17 @@ def test_jumpstart_cache_handles_bad_semantic_version_manifest_key_cache(): cache.clear.reset_mock() cache._model_id_semantic_version_manifest_key_cache.get.side_effect = [ - JumpStartVersionedModelId( - "tensorflow-ic-imagenet-inception-v3-classification-4", "999.0.0" + ( + JumpStartVersionedModelId( + "tensorflow-ic-imagenet-inception-v3-classification-4", "999.0.0" + ), + True, ), - JumpStartVersionedModelId( - "tensorflow-ic-imagenet-inception-v3-classification-4", "987.0.0" + ( + JumpStartVersionedModelId( + "tensorflow-ic-imagenet-inception-v3-classification-4", "987.0.0" + ), + True, ), ] with pytest.raises(KeyError): @@ -735,6 +766,7 @@ def test_jumpstart_local_metadata_override_header( mocked_get_json_file_and_etag_from_s3.assert_not_called() +@patch("sagemaker.jumpstart.cache.utils.emit_logs_based_on_model_specs") @patch.object(JumpStartModelsCache, "_get_json_file_and_etag_from_s3") @patch("sagemaker.jumpstart.utils.get_sagemaker_version", lambda: "2.68.3") @patch.dict( @@ -747,7 +779,10 @@ def test_jumpstart_local_metadata_override_header( @patch("sagemaker.jumpstart.cache.os.path.isdir") @patch("builtins.open") def test_jumpstart_local_metadata_override_specs( - mocked_open: Mock, mocked_is_dir: Mock, mocked_get_json_file_and_etag_from_s3: Mock + mocked_open: Mock, + mocked_is_dir: Mock, + mocked_get_json_file_and_etag_from_s3: Mock, + mock_emit_logs_based_on_model_specs, ): mocked_open.side_effect = [ @@ -776,6 +811,7 @@ def test_jumpstart_local_metadata_override_specs( mocked_get_json_file_and_etag_from_s3.assert_not_called() +@patch("sagemaker.jumpstart.cache.utils.emit_logs_based_on_model_specs") @patch.object(JumpStartModelsCache, "_get_json_file_and_etag_from_s3") @patch("sagemaker.jumpstart.utils.get_sagemaker_version", lambda: "2.68.3") @patch.dict( @@ -791,6 +827,7 @@ def test_jumpstart_local_metadata_override_specs_not_exist_both_directories( mocked_open: Mock, mocked_is_dir: Mock, mocked_get_json_file_and_etag_from_s3: Mock, + mocked_emit_logs_based_on_model_specs, ): model_id, version = "tensorflow-ic-imagenet-inception-v3-classification-4", "2.0.0" diff --git a/tests/unit/sagemaker/jumpstart/test_exceptions.py b/tests/unit/sagemaker/jumpstart/test_exceptions.py new file mode 100644 index 0000000000..555099a753 --- /dev/null +++ b/tests/unit/sagemaker/jumpstart/test_exceptions.py @@ -0,0 +1,37 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +from sagemaker.jumpstart.exceptions import ( + get_wildcard_model_version_msg, + get_old_model_version_msg, +) + + +def test_get_wildcard_model_version_msg(): + assert ( + "Using model 'mother_of_all_models' with wildcard version identifier '*'. " + "You can pin to version '1.2.3' for more stable results. " + "Note that models may have different input/output signatures after a " + "major version upgrade." + == get_wildcard_model_version_msg("mother_of_all_models", "*", "1.2.3") + ) + + +def test_get_old_model_version_msg(): + assert ( + "Using model 'mother_of_all_models' with version '1.0.0'. " + "You can upgrade to version '1.2.3' to get the latest model specifications. " + "Note that models may have different input/output signatures after a major " + "version upgrade." == get_old_model_version_msg("mother_of_all_models", "1.0.0", "1.2.3") + ) diff --git a/tests/unit/sagemaker/jumpstart/test_types.py b/tests/unit/sagemaker/jumpstart/test_types.py index e269eab5a3..82e69e1d89 100644 --- a/tests/unit/sagemaker/jumpstart/test_types.py +++ b/tests/unit/sagemaker/jumpstart/test_types.py @@ -305,6 +305,16 @@ def test_jumpstart_model_header(): assert header1 == header3 +def test_use_training_model_artifact(): + specs1 = JumpStartModelSpecs(BASE_SPEC) + assert specs1.use_training_model_artifact() + specs1.gated_bucket = True + assert not specs1.use_training_model_artifact() + specs1.gated_bucket = False + specs1.training_model_package_artifact_uris = {"region1": "blah", "region2": "blah2"} + assert not specs1.use_training_model_artifact() + + def test_jumpstart_model_specs(): specs1 = JumpStartModelSpecs(BASE_SPEC) diff --git a/tests/unit/sagemaker/jumpstart/test_utils.py b/tests/unit/sagemaker/jumpstart/test_utils.py index dac15761d9..c42012536e 100644 --- a/tests/unit/sagemaker/jumpstart/test_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_utils.py @@ -20,6 +20,7 @@ from sagemaker.jumpstart.constants import ( DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE, + ENV_VARIABLE_JUMPSTART_GATED_CONTENT_BUCKET_OVERRIDE, JUMPSTART_DEFAULT_REGION_NAME, JUMPSTART_GATED_AND_PUBLIC_BUCKET_NAME_SET, JUMPSTART_REGION_NAME_SET, @@ -35,6 +36,10 @@ ) from sagemaker.jumpstart.types import JumpStartModelHeader, JumpStartVersionedModelId from tests.unit.sagemaker.jumpstart.utils import get_spec_from_base_spec +from mock import MagicMock + + +MOCK_CLIENT = MagicMock() def random_jumpstart_s3_uri(key): @@ -78,12 +83,12 @@ def test_get_jumpstart_gated_content_bucket_no_args(): def test_get_jumpstart_gated_content_bucket_override(): - with patch.dict(os.environ, {ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE: "some-val"}): + with patch.dict(os.environ, {ENV_VARIABLE_JUMPSTART_GATED_CONTENT_BUCKET_OVERRIDE: "some-val"}): with patch("logging.Logger.info") as mocked_info_log: random_region = "random_region" assert "some-val" == utils.get_jumpstart_gated_content_bucket(random_region) mocked_info_log.assert_called_once_with( - "Using JumpStart private bucket override: 'some-val'" + "Using JumpStart gated bucket override: 'some-val'" ) @@ -883,15 +888,20 @@ def test_update_inference_tags_with_jumpstart_training_model_tags_inference(): ) -def test_jumpstart_accept_eula_logs(): +@patch("sagemaker.jumpstart.utils.accessors.JumpStartModelsAccessor._get_manifest") +def test_jumpstart_accept_eula_logs(mock_get_manifest): + mock_get_manifest.return_value = [] + def make_accept_eula_inference_spec(*largs, **kwargs): spec = get_spec_from_base_spec(model_id="pytorch-eqa-bert-base-cased", version="*") spec.hosting_eula_key = "read/the/fine/print.txt" return spec with patch("logging.Logger.info") as mocked_info_log: - utils.emit_logs_based_on_model_specs(make_accept_eula_inference_spec(), "us-east-1") - mocked_info_log.assert_called_once_with( + utils.emit_logs_based_on_model_specs( + make_accept_eula_inference_spec(), "us-east-1", MOCK_CLIENT + ) + mocked_info_log.assert_any_call( "Model '%s' requires accepting end-user license agreement (EULA). " "See https://%s.s3.%s.amazonaws.com%s/%s for terms of use.", "pytorch-eqa-bert-base-cased", @@ -902,7 +912,10 @@ def make_accept_eula_inference_spec(*largs, **kwargs): ) -def test_jumpstart_vulnerable_model_warnings(): +@patch("sagemaker.jumpstart.utils.accessors.JumpStartModelsAccessor._get_manifest") +def test_jumpstart_vulnerable_model_warnings(mock_get_manifest): + mock_get_manifest.return_value = [] + def make_vulnerable_inference_spec(*largs, **kwargs): spec = get_spec_from_base_spec(model_id="pytorch-eqa-bert-base-cased", version="*") spec.inference_vulnerable = True @@ -910,7 +923,9 @@ def make_vulnerable_inference_spec(*largs, **kwargs): return spec with patch("logging.Logger.warning") as mocked_warning_log: - utils.emit_logs_based_on_model_specs(make_vulnerable_inference_spec(), "some-region") + utils.emit_logs_based_on_model_specs( + make_vulnerable_inference_spec(), "us-west-2", MOCK_CLIENT + ) mocked_warning_log.assert_called_once_with( "Using vulnerable JumpStart model '%s' and version '%s'.", "pytorch-eqa-bert-base-cased", @@ -918,14 +933,69 @@ def make_vulnerable_inference_spec(*largs, **kwargs): ) -def test_jumpstart_deprecated_model_warnings(): +@patch("sagemaker.jumpstart.utils.accessors.JumpStartModelsAccessor._get_manifest") +def test_jumpstart_old_model_spec(mock_get_manifest): + + mock_get_manifest.return_value = [ + JumpStartModelHeader( + { + "model_id": "tensorflow-ic-imagenet-inception-v3-classification-4", + "version": "1.1.0", + "min_version": "2.49.0", + "spec_key": "community_models_specs/tensorflow-ic-imagenet-in" + "ception-v3-classification-4/specs_v1.1.0.json", + } + ), + JumpStartModelHeader( + { + "model_id": "tensorflow-ic-imagenet-inception-v3-classification-4", + "version": "1.0.0", + "min_version": "2.49.0", + "spec_key": "community_models_specs/tensorflow-ic-imagenet-" + "inception-v3-classification-4/specs_v1.0.0.json", + } + ), + ] + + with patch("logging.Logger.info") as mocked_info_log: + utils.emit_logs_based_on_model_specs( + get_spec_from_base_spec( + model_id="tensorflow-ic-imagenet-inception-v3-classification-4", version="1.0.0" + ), + "us-west-2", + MOCK_CLIENT, + ) + + mocked_info_log.assert_called_once_with( + "Using model 'tensorflow-ic-imagenet-inception-v3-classification-4' with version '1.0.0'. " + "You can upgrade to version '1.1.0' to get the latest model specifications. Note that models " + "may have different input/output signatures after a major version upgrade." + ) + + mocked_info_log.reset_mock() + + utils.emit_logs_based_on_model_specs( + get_spec_from_base_spec( + model_id="tensorflow-ic-imagenet-inception-v3-classification-4", version="1.1.0" + ), + "us-west-2", + MOCK_CLIENT, + ) + + mocked_info_log.assert_not_called() + + +@patch("sagemaker.jumpstart.utils.accessors.JumpStartModelsAccessor._get_manifest") +def test_jumpstart_deprecated_model_warnings(mock_get_manifest): + mock_get_manifest.return_value = [] + def make_deprecated_spec(*largs, **kwargs): spec = get_spec_from_base_spec(model_id="pytorch-eqa-bert-base-cased", version="*") spec.deprecated = True return spec with patch("logging.Logger.warning") as mocked_warning_log: - utils.emit_logs_based_on_model_specs(make_deprecated_spec(), "some-region") + utils.emit_logs_based_on_model_specs(make_deprecated_spec(), "us-west-2", MOCK_CLIENT) mocked_warning_log.assert_called_once_with( "Using deprecated JumpStart model 'pytorch-eqa-bert-base-cased' and version '*'." @@ -940,7 +1010,9 @@ def make_deprecated_message_spec(*largs, **kwargs): return spec with patch("logging.Logger.warning") as mocked_warning_log: - utils.emit_logs_based_on_model_specs(make_deprecated_message_spec(), "some-region") + utils.emit_logs_based_on_model_specs( + make_deprecated_message_spec(), "us-west-2", MOCK_CLIENT + ) mocked_warning_log.assert_called_once_with(deprecated_message) @@ -952,7 +1024,9 @@ def make_deprecated_warning_message_spec(*largs, **kwargs): return spec with patch("logging.Logger.warning") as mocked_warning_log: - utils.emit_logs_based_on_model_specs(make_deprecated_warning_message_spec(), "some-region") + utils.emit_logs_based_on_model_specs( + make_deprecated_warning_message_spec(), "us-west-2", MOCK_CLIENT + ) mocked_warning_log.assert_called_once_with( deprecate_warn_message, ) @@ -1117,12 +1191,6 @@ def test_is_valid_model_id_true( mock_get_manifest.assert_called_once_with( region=JUMPSTART_DEFAULT_REGION_NAME, s3_client=mock_s3_client_value ) - mock_get_model_specs.assert_called_once_with( - region=JUMPSTART_DEFAULT_REGION_NAME, - model_id="bee", - version="*", - s3_client=mock_s3_client_value, - ) @patch("sagemaker.jumpstart.utils.accessors.JumpStartModelsAccessor._get_manifest") @patch("sagemaker.jumpstart.utils.accessors.JumpStartModelsAccessor.get_model_specs") @@ -1180,13 +1248,7 @@ def test_is_valid_model_id_false(self, mock_get_model_specs: Mock, mock_get_mani mock_get_model_specs.reset_mock() mock_get_model_specs.return_value = Mock(training_supported=False) - self.assertFalse(utils.is_valid_model_id("ay", script=JumpStartScriptScope.TRAINING)) + self.assertTrue(utils.is_valid_model_id("ay", script=JumpStartScriptScope.TRAINING)) mock_get_manifest.assert_called_once_with( region=JUMPSTART_DEFAULT_REGION_NAME, s3_client=mock_s3_client_value ) - mock_get_model_specs.assert_called_once_with( - region=JUMPSTART_DEFAULT_REGION_NAME, - model_id="ay", - version="*", - s3_client=mock_s3_client_value, - ) diff --git a/tests/unit/sagemaker/model/test_model.py b/tests/unit/sagemaker/model/test_model.py index ff22b477a5..f868947cea 100644 --- a/tests/unit/sagemaker/model/test_model.py +++ b/tests/unit/sagemaker/model/test_model.py @@ -152,11 +152,100 @@ def test_prepare_container_def_with_model_data(): assert expected == container_def -def test_prepare_container_def_with_model_data_and_env(): +@patch("sagemaker.session.Session.endpoint_from_production_variants") +@patch("sagemaker.session.Session.create_model") +def test_prepare_container_def_with_accept_eula( + mock_create_model, mock_endpoint_from_production_variants +): + env = {"FOO": "BAR"} + model = Model(MODEL_IMAGE, MODEL_DATA, env=env, role=ROLE) + + model.deploy( + accept_eula=True, instance_type=INSTANCE_TYPE, initial_instance_count=INSTANCE_COUNT + ) + + expected = { + "Image": MODEL_IMAGE, + "Environment": env, + "ModelDataSource": { + "S3DataSource": { + "CompressionType": "Gzip", + "S3DataType": "S3Object", + "S3Uri": MODEL_DATA, + "ModelAccessConfig": {"AcceptEula": True}, + } + }, + } + + container_def = model.prepare_container_def(INSTANCE_TYPE, "ml.eia.medium") + assert expected == container_def + + container_def = model.prepare_container_def() + assert expected == container_def + + +@patch("sagemaker.session.Session.endpoint_from_production_variants") +@patch("sagemaker.session.Session.create_model") +def test_prepare_container_def_with_accept_eula_s3_prefix( + mock_create_model, mock_endpoint_from_production_variants +): + env = {"FOO": "BAR"} + model_data = { + "S3DataSource": { + "S3Uri": "s3://blah-cache-prod-us-west-2/huggingface-infer/prepack/v1.0.1/", + "S3DataType": "S3Prefix", + "CompressionType": "None", + } + } + model = Model(MODEL_IMAGE, model_data, env=env, role=ROLE) + + model.deploy( + accept_eula=True, instance_type=INSTANCE_TYPE, initial_instance_count=INSTANCE_COUNT + ) + + expected = { + "Environment": {"FOO": "BAR"}, + "Image": "mi", + "ModelDataSource": { + "S3DataSource": { + "CompressionType": "None", + "ModelAccessConfig": {"AcceptEula": True}, + "S3DataType": "S3Prefix", + "S3Uri": "s3://blah-cache-prod-us-west-2/huggingface-infer/prepack/v1.0.1/", + }, + }, + } + + container_def = model.prepare_container_def(INSTANCE_TYPE, "ml.eia.medium") + assert expected == container_def + + container_def = model.prepare_container_def() + assert expected == container_def + + +def test_prepare_container_def_with_model_data_and_env_s3_gzip(): env = {"FOO": "BAR"} model = Model(MODEL_IMAGE, MODEL_DATA, env=env) - expected = {"Image": MODEL_IMAGE, "Environment": env, "ModelDataUrl": MODEL_DATA} + expected = { + "Image": MODEL_IMAGE, + "Environment": env, + "ModelDataUrl": MODEL_DATA, + } + + container_def = model.prepare_container_def(INSTANCE_TYPE, "ml.eia.medium") + assert expected == container_def + + container_def = model.prepare_container_def() + assert expected == container_def + + +def test_prepare_container_def_with_model_data_and_env(): + env = {"FOO": "BAR"} + model_data = "s3://my-bucket/my-model" + model = Model(MODEL_IMAGE, model_data, env=env) + + expected = {"Image": MODEL_IMAGE, "Environment": env, "ModelDataUrl": model_data} container_def = model.prepare_container_def(INSTANCE_TYPE, "ml.eia.medium") assert expected == container_def diff --git a/tests/unit/sagemaker/utilities/test_cache.py b/tests/unit/sagemaker/utilities/test_cache.py index 10fbe45767..619468c2aa 100644 --- a/tests/unit/sagemaker/utilities/test_cache.py +++ b/tests/unit/sagemaker/utilities/test_cache.py @@ -32,14 +32,14 @@ def test_cache_retrieves_item(): ) my_cache.put(5) - assert my_cache.get(5, False) == retrieval_function(key=5) + assert my_cache.get(5, False) == (retrieval_function(key=5), True) my_cache.put(6, 7) - assert my_cache.get(6, False) == 7 + assert my_cache.get(6, False) == (7, True) assert len(my_cache) == 2 my_cache.put(5, 6) - assert my_cache.get(5, False) == 6 + assert my_cache.get(5, False) == (6, True) assert len(my_cache) == 2 with pytest.raises(KeyError): @@ -65,7 +65,7 @@ def test_cache_invalidates_old_item(): mock_datetime.now.return_value = mock_curr_time my_cache.put(5) mock_datetime.now.return_value += datetime.timedelta(milliseconds=0.5) - assert my_cache.get(5, False) == retrieval_function(key=5) + assert my_cache.get(5, False) == (retrieval_function(key=5), True) def test_cache_fetches_new_item(): @@ -80,13 +80,13 @@ def test_cache_fetches_new_item(): mock_datetime.now.return_value = mock_curr_time my_cache.put(5, 10) mock_datetime.now.return_value += datetime.timedelta(milliseconds=2) - assert my_cache.get(5) == retrieval_function(key=5) + assert my_cache.get(5) == (retrieval_function(key=5), True) with patch("datetime.datetime") as mock_datetime: mock_datetime.now.return_value = mock_curr_time my_cache.put(5, 10) mock_datetime.now.return_value += datetime.timedelta(milliseconds=0.5) - assert my_cache.get(5, False) == 10 + assert my_cache.get(5, False) == (10, True) mock_datetime.now.return_value += datetime.timedelta(milliseconds=0.75) with pytest.raises(KeyError): my_cache.get(5, False) @@ -108,7 +108,7 @@ def test_cache_removes_old_items_once_size_limit_reached(): assert len(my_cache) == 5 with pytest.raises(KeyError): my_cache.get(1, False) - assert my_cache.get(2, False) == retrieval_function(key=2) + assert my_cache.get(2, False) == (retrieval_function(key=2), True) def test_cache_get_with_data_source_fallback(): @@ -120,7 +120,7 @@ def test_cache_get_with_data_source_fallback(): for i in range(10): val = my_cache.get(i) - assert val == retrieval_function(key=i) + assert val == (retrieval_function(key=i), False) assert len(my_cache) == 5