diff --git a/src/sagemaker/jumpstart/artifacts.py b/src/sagemaker/jumpstart/artifacts.py index bae28ba8a0..4c12923ca9 100644 --- a/src/sagemaker/jumpstart/artifacts.py +++ b/src/sagemaker/jumpstart/artifacts.py @@ -173,10 +173,10 @@ def _retrieve_image_uri( def _retrieve_model_uri( model_id: str, model_version: str, - model_scope: Optional[str], - region: Optional[str], - tolerate_vulnerable_model: bool, - tolerate_deprecated_model: bool, + model_scope: Optional[str] = None, + region: Optional[str] = None, + tolerate_vulnerable_model: bool = False, + tolerate_deprecated_model: bool = False, ): """Retrieves the model artifact S3 URI for the model matching the given arguments. @@ -219,7 +219,11 @@ def _retrieve_model_uri( ) if model_scope == JumpStartScriptScope.INFERENCE: - model_artifact_key = model_specs.hosting_artifact_key + model_artifact_key = ( + getattr(model_specs, "hosting_prepacked_artifact_key", None) + or model_specs.hosting_artifact_key + ) + elif model_scope == JumpStartScriptScope.TRAINING: model_artifact_key = model_specs.training_artifact_key diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index 4c07fad9fb..c9c6e1fbe5 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -293,6 +293,7 @@ class JumpStartModelSpecs(JumpStartDataHolderType): "training_vulnerabilities", "deprecated", "metrics", + "hosting_prepacked_artifact_key", ] def __init__(self, spec: Dict[str, Any]): @@ -330,6 +331,9 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: self.training_vulnerabilities: List[str] = json_obj["training_vulnerabilities"] self.deprecated: bool = bool(json_obj["deprecated"]) self.metrics: Optional[List[Dict[str, str]]] = json_obj.get("metrics", None) + self.hosting_prepacked_artifact_key: Optional[str] = json_obj.get( + "hosting_prepacked_artifact_key", None + ) if self.training_supported: self.training_ecr_specs: JumpStartECRSpecs = JumpStartECRSpecs( diff --git a/tests/unit/sagemaker/jumpstart/constants.py b/tests/unit/sagemaker/jumpstart/constants.py index 8c6bca1a48..ca61398382 100644 --- a/tests/unit/sagemaker/jumpstart/constants.py +++ b/tests/unit/sagemaker/jumpstart/constants.py @@ -12,6 +12,91 @@ # language governing permissions and limitations under the License. from __future__ import absolute_import +SPECIAL_MODEL_SPECS_DICT = { + "huggingface-text2text-flan-t5-xxl-fp16": { + "model_id": "huggingface-text2text-flan-t5-xxl-fp16", + "url": "https://huggingface.co/google/flan-t5-xxl", + "version": "1.0.0", + "min_sdk_version": "2.130.0", + "training_supported": False, + "incremental_training_supported": False, + "hosting_ecr_specs": { + "framework": "pytorch", + "framework_version": "1.12.0", + "py_version": "py38", + "huggingface_transformers_version": "4.17.0", + }, + "hosting_artifact_key": "huggingface-infer/infer-huggingface-text2text-flan-t5-xxl-fp16.tar.gz", + "hosting_script_key": "source-directory-tarballs/huggingface/inference/text2text/v1.0.2/sourcedir.tar.gz", + "hosting_prepacked_artifact_key": "huggingface-infer/prepack/v1.0.0/infer-prepack-huggingface-" + "text2text-flan-t5-xxl-fp16.tar.gz", + "hosting_prepacked_artifact_version": "1.0.0", + "inference_vulnerable": False, + "inference_dependencies": [ + "accelerate==0.16.0", + "bitsandbytes==0.37.0", + "filelock==3.9.0", + "huggingface-hub==0.12.0", + "regex==2022.7.9", + "tokenizers==0.13.2", + "transformers==4.26.0", + ], + "inference_vulnerabilities": [], + "training_vulnerable": False, + "training_dependencies": [], + "training_vulnerabilities": [], + "deprecated": False, + "inference_environment_variables": [ + { + "name": "SAGEMAKER_PROGRAM", + "type": "text", + "default": "inference.py", + "scope": "container", + }, + { + "name": "SAGEMAKER_SUBMIT_DIRECTORY", + "type": "text", + "default": "/opt/ml/model/code", + "scope": "container", + }, + { + "name": "SAGEMAKER_CONTAINER_LOG_LEVEL", + "type": "text", + "default": "20", + "scope": "container", + }, + { + "name": "MODEL_CACHE_ROOT", + "type": "text", + "default": "/opt/ml/model", + "scope": "container", + }, + {"name": "SAGEMAKER_ENV", "type": "text", "default": "1", "scope": "container"}, + { + "name": "SAGEMAKER_MODEL_SERVER_WORKERS", + "type": "text", + "default": "1", + "scope": "container", + }, + { + "name": "SAGEMAKER_MODEL_SERVER_TIMEOUT", + "type": "text", + "default": "3600", + "scope": "container", + }, + ], + "metrics": [], + "default_inference_instance_type": "ml.g5.12xlarge", + "supported_inference_instance_types": [ + "ml.g5.12xlarge", + "ml.g5.24xlarge", + "ml.p3.8xlarge", + "ml.p3.16xlarge", + "ml.g4dn.12xlarge", + ], + } +} + PROTOTYPICAL_MODEL_SPECS_DICT = { "pytorch-eqa-bert-base-cased": { "model_id": "pytorch-eqa-bert-base-cased", @@ -1093,6 +1178,7 @@ "training_artifact_key": "pytorch-training/train-pytorch-ic-mobilenet-v2.tar.gz", "hosting_script_key": "source-directory-tarballs/pytorch/inference/ic/v1.0.0/sourcedir.tar.gz", "training_script_key": "source-directory-tarballs/pytorch/transfer_learning/ic/v1.0.0/sourcedir.tar.gz", + "hosting_prepacked_artifact_key": None, "hyperparameters": [ { "name": "epochs", diff --git a/tests/unit/sagemaker/jumpstart/utils.py b/tests/unit/sagemaker/jumpstart/utils.py index 7b1fc45aeb..d65e05650c 100644 --- a/tests/unit/sagemaker/jumpstart/utils.py +++ b/tests/unit/sagemaker/jumpstart/utils.py @@ -29,6 +29,7 @@ BASE_MANIFEST, BASE_SPEC, BASE_HEADER, + SPECIAL_MODEL_SPECS_DICT, ) @@ -92,6 +93,18 @@ def get_prototype_model_spec( return specs +def get_special_model_spec( + region: str = None, model_id: str = None, version: str = None +) -> JumpStartModelSpecs: + """This function mocks cache accessor functions. For this mock, + we only retrieve model specs based on the model ID. This is reserved + for special specs. + """ + + specs = JumpStartModelSpecs(SPECIAL_MODEL_SPECS_DICT[model_id]) + return specs + + def get_spec_from_base_spec( _obj: JumpStartModelsCache = None, region: str = None, diff --git a/tests/unit/sagemaker/model_uris/jumpstart/test_combined_artifact.py b/tests/unit/sagemaker/model_uris/jumpstart/test_combined_artifact.py new file mode 100644 index 0000000000..92eac5a342 --- /dev/null +++ b/tests/unit/sagemaker/model_uris/jumpstart/test_combined_artifact.py @@ -0,0 +1,38 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +from mock.mock import patch + +from sagemaker import model_uris + +from tests.unit.sagemaker.jumpstart.utils import get_special_model_spec + + +@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") +def test_jumpstart_combined_artifacts(patched_get_model_specs): + + patched_get_model_specs.side_effect = get_special_model_spec + + model_id_combined_model_artifact = "huggingface-text2text-flan-t5-xxl-fp16" + + uri = model_uris.retrieve( + region="us-west-2", + model_scope="inference", + model_id=model_id_combined_model_artifact, + model_version="*", + ) + assert ( + uri == "s3://jumpstart-cache-prod-us-west-2/huggingface-infer/" + "prepack/v1.0.0/infer-prepack-huggingface-text2text-flan-t5-xxl-fp16.tar.gz" + )