diff --git a/requirements/extras/test_requirements.txt b/requirements/extras/test_requirements.txt
index 2713f88c04..ad6322365f 100644
--- a/requirements/extras/test_requirements.txt
+++ b/requirements/extras/test_requirements.txt
@@ -36,3 +36,4 @@ onnx>=1.15.0
 nbformat>=5.9,<6
 accelerate>=0.24.1,<=0.27.0
 schema==0.7.5
+tensorflow>=2.1,<=2.16
diff --git a/src/sagemaker/serve/builder/model_builder.py b/src/sagemaker/serve/builder/model_builder.py
index 06b3d70aeb..52268ea40c 100644
--- a/src/sagemaker/serve/builder/model_builder.py
+++ b/src/sagemaker/serve/builder/model_builder.py
@@ -29,6 +29,7 @@
 from sagemaker.serializers import NumpySerializer, TorchTensorSerializer
 from sagemaker.deserializers import JSONDeserializer, TorchTensorDeserializer
 from sagemaker.serve.builder.schema_builder import SchemaBuilder
+from sagemaker.serve.builder.tf_serving_builder import TensorflowServing
 from sagemaker.serve.mode.function_pointers import Mode
 from sagemaker.serve.mode.sagemaker_endpoint_mode import SageMakerEndpointMode
 from sagemaker.serve.mode.local_container_mode import LocalContainerMode
@@ -59,6 +60,7 @@
 from sagemaker.serve.spec.inference_spec import InferenceSpec
 from sagemaker.serve.utils import task
 from sagemaker.serve.utils.exceptions import TaskNotFoundException
+from sagemaker.serve.utils.lineage_utils import _maintain_lineage_tracking_for_mlflow_model
 from sagemaker.serve.utils.predictors import _get_local_mode_predictor
 from sagemaker.serve.utils.hardware_detector import (
     _get_gpu_info,
@@ -89,12 +91,13 @@
     ModelServer.TORCHSERVE,
     ModelServer.TRITON,
     ModelServer.DJL_SERVING,
+    ModelServer.TENSORFLOW_SERVING,
 }


-# pylint: disable=attribute-defined-outside-init, disable=E1101
+# pylint: disable=attribute-defined-outside-init, disable=E1101, disable=R0901
 @dataclass
-class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers):
+class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers, TensorflowServing):
     """Class that builds a deployable model.

     Args:
@@ -493,6 +496,12 @@ def _model_builder_register_wrapper(self, *args, **kwargs):
         self.pysdk_model.model_package_arn = new_model_package.model_package_arn
         new_model_package.deploy = self._model_builder_deploy_model_package_wrapper
         self.model_package = new_model_package
+        if getattr(self, "_is_mlflow_model", False) and self.mode == Mode.SAGEMAKER_ENDPOINT:
+            _maintain_lineage_tracking_for_mlflow_model(
+                mlflow_model_path=self.model_metadata[MLFLOW_MODEL_PATH],
+                s3_upload_path=self.s3_upload_path,
+                sagemaker_session=self.sagemaker_session,
+            )
         return new_model_package

     def _model_builder_deploy_model_package_wrapper(self, *args, **kwargs):
@@ -551,12 +560,19 @@ def _model_builder_deploy_wrapper(
         if "endpoint_logging" not in kwargs:
             kwargs["endpoint_logging"] = True
-        return self._original_deploy(
+        predictor = self._original_deploy(
             *args,
             instance_type=instance_type,
             initial_instance_count=initial_instance_count,
             **kwargs,
         )
+        if getattr(self, "_is_mlflow_model", False) and self.mode == Mode.SAGEMAKER_ENDPOINT:
+            _maintain_lineage_tracking_for_mlflow_model(
+                mlflow_model_path=self.model_metadata[MLFLOW_MODEL_PATH],
+                s3_upload_path=self.s3_upload_path,
+                sagemaker_session=self.sagemaker_session,
+            )
+        return predictor

     def _overwrite_mode_in_deploy(self, overwrite_mode: str):
         """Mode overwritten by customer during model.deploy()"""
@@ -728,7 +744,7 @@ def build(  # pylint: disable=R0911
                 " for production at this moment."
             )
             self._initialize_for_mlflow()
-            _validate_input_for_mlflow(self.model_server)
+            _validate_input_for_mlflow(self.model_server, self.env_vars.get("MLFLOW_MODEL_FLAVOR"))

         if isinstance(self.model, str):
             model_task = None
@@ -767,6 +783,9 @@ def build(  # pylint: disable=R0911
         if self.model_server == ModelServer.TRITON:
             return self._build_for_triton()

+        if self.model_server == ModelServer.TENSORFLOW_SERVING:
+            return self._build_for_tensorflow_serving()
+
         raise ValueError("%s model server is not supported" % self.model_server)

     def save(
diff --git a/src/sagemaker/serve/builder/tf_serving_builder.py b/src/sagemaker/serve/builder/tf_serving_builder.py
new file mode 100644
index 0000000000..42c548f4e4
--- /dev/null
+++ b/src/sagemaker/serve/builder/tf_serving_builder.py
@@ -0,0 +1,129 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+"""Holds the mixin logic to support deployment of TensorFlow Serving models"""
+from __future__ import absolute_import
+import logging
+import os
+from pathlib import Path
+from abc import ABC, abstractmethod
+
+from sagemaker import Session
+from sagemaker.serve.detector.pickler import save_pkl
+from sagemaker.serve.model_server.tensorflow_serving.prepare import prepare_for_tf_serving
+from sagemaker.tensorflow import TensorFlowModel, TensorFlowPredictor
+
+logger = logging.getLogger(__name__)
+
+_TF_SERVING_MODEL_BUILDER_ENTRY_POINT = "inference.py"
+_CODE_FOLDER = "code"
+
+
+# pylint: disable=attribute-defined-outside-init, disable=E1101
+class TensorflowServing(ABC):
+    """TensorflowServing build logic for ModelBuilder()"""
+
+    def __init__(self):
+        self.model = None
+        self.serve_settings = None
+        self.sagemaker_session = None
+        self.model_path = None
+        self.dependencies = None
+        self.modes = None
+        self.mode = None
+        self.model_server = None
+        self.image_uri = None
+        self._is_custom_image_uri = False
+        self.image_config = None
+        self.vpc_config = None
+        self._original_deploy = None
+        self.secret_key = None
+        self.engine = None
+        self.pysdk_model = None
+        self.schema_builder = None
+        self.env_vars = None
+
+    @abstractmethod
+    def _prepare_for_mode(self):
+        """Prepare model artifacts based on mode."""
+
+    @abstractmethod
+    def _get_client_translators(self):
+        """Set up client marshaller based on schema builder."""
+
+    def _save_schema_builder(self):
+        """Save schema builder for TensorFlow Serving."""
+        if not os.path.exists(self.model_path):
+            os.makedirs(self.model_path)
+
+        code_path = Path(self.model_path).joinpath("code")
+        save_pkl(code_path, self.schema_builder)
+
+    def _get_tensorflow_predictor(
+        self, endpoint_name: str, sagemaker_session: Session
+    ) -> TensorFlowPredictor:
+        """Creates a TensorFlowPredictor object"""
+        serializer, deserializer = self._get_client_translators()
+
+        return TensorFlowPredictor(
+            endpoint_name=endpoint_name,
+            sagemaker_session=sagemaker_session,
+            serializer=serializer,
+            deserializer=deserializer,
+        )
+
+    def _validate_for_tensorflow_serving(self):
+        """Validate for TensorFlow Serving"""
+        if not getattr(self, "_is_mlflow_model", False):
+            raise ValueError("TensorFlow Serving is currently only supported for MLflow models.")
+
+    def _create_tensorflow_model(self):
+        """Creates a TensorFlow model object"""
+        self.pysdk_model = TensorFlowModel(
+            image_uri=self.image_uri,
+            image_config=self.image_config,
+            vpc_config=self.vpc_config,
+            model_data=self.s3_upload_path,
+            role=self.serve_settings.role_arn,
+            env=self.env_vars,
+            sagemaker_session=self.sagemaker_session,
+            predictor_cls=self._get_tensorflow_predictor,
+        )
+
+        self.pysdk_model.mode = self.mode
+        self.pysdk_model.modes = self.modes
+        self.pysdk_model.serve_settings = self.serve_settings
+
+        self._original_deploy = self.pysdk_model.deploy
+        self.pysdk_model.deploy = self._model_builder_deploy_wrapper
+        self._original_register = self.pysdk_model.register
+        self.pysdk_model.register = self._model_builder_register_wrapper
+        self.model_package = None
+        return self.pysdk_model
+
+    def _build_for_tensorflow_serving(self):
+        """Build the model for TensorFlow Serving"""
+        self._validate_for_tensorflow_serving()
+        self._save_schema_builder()
+
+        if not self.image_uri:
+            raise ValueError("image_uri is not set for TensorFlow Serving")
+
+        self.secret_key = prepare_for_tf_serving(
+            model_path=self.model_path,
+            shared_libs=self.shared_libs,
+            dependencies=self.dependencies,
+        )
+
+        self._prepare_for_mode()
+
+        return self._create_tensorflow_model()
diff --git a/src/sagemaker/serve/mode/local_container_mode.py b/src/sagemaker/serve/mode/local_container_mode.py
index 362a3804de..f940e2959c 100644
--- a/src/sagemaker/serve/mode/local_container_mode.py
+++ b/src/sagemaker/serve/mode/local_container_mode.py
@@ -11,6 +11,7 @@
 import docker

 from sagemaker.base_predictor import PredictorBase
+from sagemaker.serve.model_server.tensorflow_serving.server import LocalTensorflowServing
 from sagemaker.serve.spec.inference_spec import InferenceSpec
 from sagemaker.serve.builder.schema_builder import SchemaBuilder
 from sagemaker.serve.utils.logging_agent import pull_logs
@@ -34,7 +35,12 @@


 class LocalContainerMode(
-    LocalTorchServe, LocalDJLServing, LocalTritonServer, LocalTgiServing, LocalMultiModelServer
+    LocalTorchServe,
+    LocalDJLServing,
+    LocalTritonServer,
+    LocalTgiServing,
+    LocalMultiModelServer,
+    LocalTensorflowServing,
 ):
     """A class that holds methods to deploy model to a container in local environment"""

@@ -141,6 +147,15 @@ def create_server(
                 env_vars=env_vars if env_vars else self.env_vars,
             )
             self._ping_container = self._multi_model_server_deep_ping
+        elif self.model_server == ModelServer.TENSORFLOW_SERVING:
+            self._start_tensorflow_serving(
+                client=self.client,
+                image=image,
+                model_path=model_path if model_path else self.model_path,
+                secret_key=secret_key,
+                env_vars=env_vars if env_vars else self.env_vars,
+            )
+            self._ping_container = self._tensorflow_serving_deep_ping

         # allow some time for container to be ready
         time.sleep(10)
diff --git a/src/sagemaker/serve/mode/sagemaker_endpoint_mode.py b/src/sagemaker/serve/mode/sagemaker_endpoint_mode.py
index 0fdc425b31..24acfc6a2f 100644
--- a/src/sagemaker/serve/mode/sagemaker_endpoint_mode.py
+++ b/src/sagemaker/serve/mode/sagemaker_endpoint_mode.py
@@ -6,6 +6,7 @@
 import logging
 from typing import Type

+from sagemaker.serve.model_server.tensorflow_serving.server import SageMakerTensorflowServing
 from sagemaker.session import Session
 from sagemaker.serve.utils.types import ModelServer
 from sagemaker.serve.spec.inference_spec import InferenceSpec
@@ -24,6 +25,7 @@ class SageMakerEndpointMode(
     SageMakerDjlServing,
     SageMakerTgiServing,
     SageMakerMultiModelServer,
+    SageMakerTensorflowServing,
 ):
     """Holds the required method to deploy a model to a SageMaker Endpoint"""

@@ -107,4 +109,13 @@ def prepare(
                 image=image,
             )

+        if self.model_server == ModelServer.TENSORFLOW_SERVING:
+            return self._upload_tensorflow_serving_artifacts(
+                model_path=model_path,
+                sagemaker_session=sagemaker_session,
+                secret_key=secret_key,
+                s3_model_data_url=s3_model_data_url,
+                image=image,
+            )
+
         raise ValueError("%s model server is not supported" % self.model_server)
diff --git a/src/sagemaker/serve/model_format/mlflow/constants.py b/src/sagemaker/serve/model_format/mlflow/constants.py
index 00ef76170c..5d2ac484b5 100644
--- a/src/sagemaker/serve/model_format/mlflow/constants.py
+++ b/src/sagemaker/serve/model_format/mlflow/constants.py
@@ -19,6 +19,12 @@
     "py39": "1.13.1",
     "py310": "2.2.0",
 }
+MODEL_PACKAGE_ARN_REGEX = (
+    r"^arn:aws:sagemaker:[a-z0-9\-]+:[0-9]{12}:model-package\/[" r"a-zA-Z0-9\-_\/\.]+$"
+)
+MLFLOW_RUN_ID_REGEX = r"^runs:/[a-zA-Z0-9]+(/[a-zA-Z0-9]+)*$"
+MLFLOW_REGISTRY_PATH_REGEX = r"^models:/[a-zA-Z0-9\-_\.]+(/[0-9]+)*$"
+S3_PATH_REGEX = r"^s3:\/\/[a-zA-Z0-9\-_\.]+\/[a-zA-Z0-9\-_\/\.]*$"
 MLFLOW_MODEL_PATH = "MLFLOW_MODEL_PATH"
 MLFLOW_METADATA_FILE = "MLmodel"
 MLFLOW_PIP_DEPENDENCY_FILE = "requirements.txt"
@@ -34,8 +40,12 @@
     "spark": "pyspark",
     "onnx": "onnxruntime",
 }
-FLAVORS_WITH_FRAMEWORK_SPECIFIC_DLC_SUPPORT = [  # will extend to keras and tf
-    "sklearn",
-    "pytorch",
-    "xgboost",
-]
+TENSORFLOW_SAVED_MODEL_NAME = "saved_model.pb"
+FLAVORS_WITH_FRAMEWORK_SPECIFIC_DLC_SUPPORT = {
+    "sklearn": "sklearn",
+    "pytorch": "pytorch",
+    "xgboost": "xgboost",
+    "tensorflow": "tensorflow",
+    "keras": "tensorflow",
+}
+FLAVORS_DEFAULT_WITH_TF_SERVING = ["keras", "tensorflow"]
diff --git a/src/sagemaker/serve/model_format/mlflow/utils.py b/src/sagemaker/serve/model_format/mlflow/utils.py
index c9a8093a79..b67de08d78 100644
--- a/src/sagemaker/serve/model_format/mlflow/utils.py
+++ b/src/sagemaker/serve/model_format/mlflow/utils.py
@@ -13,7 +13,8 @@
 """Holds the util functions used for MLflow model format"""
 from __future__ import absolute_import

-from typing import Optional, Dict, Any
+from pathlib import Path
+from typing import Optional, Dict, Any, Union
 import yaml
 import logging
 import shutil
@@ -30,6 +31,8 @@
     DEFAULT_PYTORCH_VERSION,
     MLFLOW_METADATA_FILE,
     MLFLOW_PIP_DEPENDENCY_FILE,
+    FLAVORS_DEFAULT_WITH_TF_SERVING,
+    TENSORFLOW_SAVED_MODEL_NAME,
 )

 logger = logging.getLogger(__name__)
@@ -44,7 +47,8 @@ def _get_default_model_server_for_mlflow(deployment_flavor: str) -> ModelServer:
     Returns:
         str: The model server chosen for given model flavor.
     """
-    # TODO: implement real logic here based on mlflow flavor
+    if deployment_flavor in FLAVORS_DEFAULT_WITH_TF_SERVING:
+        return ModelServer.TENSORFLOW_SERVING
     return ModelServer.TORCHSERVE


@@ -344,15 +348,16 @@ def _select_container_for_mlflow_model(
             f"specific DLC support. Defaulting to generic image..."
         )
         return _get_default_image_for_mlflow(python_version, region, instance_type)

-    framework_version = _get_framework_version_from_requirements(
-        deployment_flavor, requirement_path
-    )
+
+    framework_to_use = FLAVORS_WITH_FRAMEWORK_SPECIFIC_DLC_SUPPORT.get(deployment_flavor)
+    framework_version = _get_framework_version_from_requirements(framework_to_use, requirement_path)

     logger.info("Auto-detected deployment flavor is %s", deployment_flavor)
+    logger.info("Auto-detected framework to use is %s", framework_to_use)
     logger.info("Auto-detected framework version is %s", framework_version)

     casted_versions = (
-        _cast_to_compatible_version(deployment_flavor, framework_version)
+        _cast_to_compatible_version(framework_to_use, framework_version)
         if framework_version
         else (None,)
     )
@@ -361,7 +366,7 @@ def _select_container_for_mlflow_model(
     for casted_version in casted_versions:
         try:
             image_uri = image_uris.retrieve(
-                framework=deployment_flavor,
+                framework=framework_to_use,
                 region=region,
                 version=casted_version,
                 image_scope="inference",
@@ -392,17 +397,60 @@
     )


-def _validate_input_for_mlflow(model_server: ModelServer) -> None:
+def _validate_input_for_mlflow(model_server: ModelServer, deployment_flavor: str) -> None:
     """Validates arguments provided with mlflow models.

     Args:
     - model_server (ModelServer): Model server used for orchestrating mlflow model.
+    - deployment_flavor (str): The flavor mlflow model will be deployed with.

     Raises:
     - ValueError: If model server is not torchserve.
     """
-    if model_server != ModelServer.TORCHSERVE:
+    if model_server != ModelServer.TORCHSERVE and model_server != ModelServer.TENSORFLOW_SERVING:
         raise ValueError(
             f"{model_server} is currently not supported for MLflow Model. "
             f"Please choose another model server."
         )
+    if (
+        model_server == ModelServer.TENSORFLOW_SERVING
+        and deployment_flavor not in FLAVORS_DEFAULT_WITH_TF_SERVING
+    ):
+        raise ValueError(
+            "TensorFlow Serving is currently only supported for the following "
+            "deployment flavors: {}".format(FLAVORS_DEFAULT_WITH_TF_SERVING)
+        )
+
+
+def _get_saved_model_path_for_tensorflow_and_keras_flavor(model_path: str) -> Optional[str]:
+    """Recursively searches for a TensorFlow SavedModel.
+
+    Args:
+        model_path (str): The root directory to start the search from.
+
+    Returns:
+        Optional[str]: The absolute path to the directory containing 'saved_model.pb'.
+    """
+    for dirpath, dirnames, filenames in os.walk(model_path):
+        if TENSORFLOW_SAVED_MODEL_NAME in filenames:
+            return os.path.abspath(dirpath)
+
+    return None
+
+
+def _move_contents(src_dir: Union[str, Path], dest_dir: Union[str, Path]) -> None:
+    """Moves all contents of a source directory to a specified destination directory.
+
+    Args:
+        src_dir (Union[str, Path]): The path to the source directory.
+        dest_dir (Union[str, Path]): The path to the destination directory.
+
+    """
+    _src_dir = Path(os.path.normpath(src_dir))
+    _dest_dir = Path(os.path.normpath(dest_dir))
+
+    _dest_dir.mkdir(parents=True, exist_ok=True)
+
+    for item in _src_dir.iterdir():
+        _dest_path = _dest_dir / item.name
+        shutil.move(str(item), str(_dest_path))
diff --git a/src/sagemaker/serve/model_server/tensorflow_serving/__init__.py b/src/sagemaker/serve/model_server/tensorflow_serving/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/src/sagemaker/serve/model_server/tensorflow_serving/inference.py b/src/sagemaker/serve/model_server/tensorflow_serving/inference.py
new file mode 100644
index 0000000000..928278e3c6
--- /dev/null
+++ b/src/sagemaker/serve/model_server/tensorflow_serving/inference.py
@@ -0,0 +1,147 @@
+"""This module is for SageMaker inference.py."""
+
+from __future__ import absolute_import
+import os
+import io
+import json
+import cloudpickle
+import shutil
+import platform
+from pathlib import Path
+from sagemaker.serve.validations.check_integrity import perform_integrity_check
+import logging
+
+logger = logging.getLogger(__name__)
+
+schema_builder = None
+SHARED_LIBS_DIR = Path(__file__).parent.parent.joinpath("shared_libs")
+SERVE_PATH = Path(__file__).parent.joinpath("serve.pkl")
+METADATA_PATH = Path(__file__).parent.joinpath("metadata.json")
+
+
+def input_handler(data, context):
+    """Pre-process request input before it is sent to TensorFlow Serving REST API
+
+    Args:
+        data (obj): the request data, in format of dict or string
+        context (Context): an object containing request and configuration details
+
+    Returns:
+        (dict): a JSON-serializable dict that contains request body and headers
+    """
+    read_data = data.read()
+    deserialized_data = None
+    try:
+        if hasattr(schema_builder, "custom_input_translator"):
+            deserialized_data = schema_builder.custom_input_translator.deserialize(
+                io.BytesIO(read_data), context.request_content_type
+            )
+        else:
+            deserialized_data = schema_builder.input_deserializer.deserialize(
+                io.BytesIO(read_data), context.request_content_type
+            )
+    except Exception as e:
+        logger.error("Encountered error: %s in deserialize_request." % e)
+        raise Exception("Encountered error in deserialize_request.") from e
+
+    try:
+        return json.dumps({"instances": _convert_for_serialization(deserialized_data)})
+    except Exception as e:
+        logger.error(
+            "Encountered error: %s in deserialize_request. "
+            "Deserialized data is not json serializable. " % e
+        )
+        raise Exception("Encountered error in deserialize_request.") from e
+
+
+def output_handler(data, context):
+    """Post-process TensorFlow Serving output before it is returned to the client.
+
+    Args:
+        data (obj): the TensorFlow serving response
+        context (Context): an object containing request and configuration details
+
+    Returns:
+        (bytes, string): data to return to client, response content type
+    """
+    if data.status_code != 200:
+        raise ValueError(data.content.decode("utf-8"))
+
+    response_content_type = context.accept_header
+    prediction = data.content
+    try:
+        prediction_dict = json.loads(prediction.decode("utf-8"))
+        if hasattr(schema_builder, "custom_output_translator"):
+            return (
+                schema_builder.custom_output_translator.serialize(
+                    prediction_dict["predictions"], response_content_type
+                ),
+                response_content_type,
+            )
+        else:
+            return schema_builder.output_serializer.serialize(prediction), response_content_type
+    except Exception as e:
+        logger.error("Encountered error: %s in serialize_response." % e)
+        raise Exception("Encountered error in serialize_response.") from e
+
+
+def _run_preflight_diagnostics():
+    _py_vs_parity_check()
+    _pickle_file_integrity_check()
+
+
+def _py_vs_parity_check():
+    container_py_vs = platform.python_version()
+    local_py_vs = os.getenv("LOCAL_PYTHON")
+
+    if not local_py_vs or container_py_vs.split(".")[1] != local_py_vs.split(".")[1]:
+        logger.warning(
+            f"The local python version {local_py_vs} differs from the python version "
+            f"{container_py_vs} on the container. Please align the two to avoid unexpected behavior"
+        )
+
+
+def _pickle_file_integrity_check():
+    with open(SERVE_PATH, "rb") as f:
+        buffer = f.read()
+
+    perform_integrity_check(buffer=buffer, metadata_path=METADATA_PATH)
+
+
+def _set_up_schema_builder():
+    """Sets up the schema_builder object."""
+    global schema_builder
+    with open(SERVE_PATH, "rb") as f:
+        schema_builder = cloudpickle.load(f)
+
+
+def _set_up_shared_libs():
+    """Sets up the shared libs path."""
+    if SHARED_LIBS_DIR.exists():
+        # before importing, place dynamic linked libraries in shared lib path
+        shutil.copytree(SHARED_LIBS_DIR, "/lib", dirs_exist_ok=True)
+
+
+def _convert_for_serialization(deserialized_data):
+    """Attempt to convert non-serializable objects to a serializable form.
+
+    Args:
+        deserialized_data: The object to convert.
+
+    Returns:
+        The converted object if it was not originally serializable, otherwise the original object.
+    """
+    import numpy as np
+    import pandas as pd
+
+    if isinstance(deserialized_data, np.ndarray):
+        return deserialized_data.tolist()
+    elif isinstance(deserialized_data, pd.DataFrame):
+        return deserialized_data.to_dict(orient="list")
+    elif isinstance(deserialized_data, pd.Series):
+        return deserialized_data.tolist()
+    return deserialized_data
+
+
+# on import, execute
+_run_preflight_diagnostics()
+_set_up_schema_builder()
+_set_up_shared_libs()
diff --git a/src/sagemaker/serve/model_server/tensorflow_serving/prepare.py b/src/sagemaker/serve/model_server/tensorflow_serving/prepare.py
new file mode 100644
index 0000000000..e9aa4aafff
--- /dev/null
+++ b/src/sagemaker/serve/model_server/tensorflow_serving/prepare.py
@@ -0,0 +1,67 @@
+"""Module for artifacts preparation for tensorflow_serving"""
+
+from __future__ import absolute_import
+from pathlib import Path
+import shutil
+from typing import List, Dict, Any
+
+from sagemaker.serve.model_format.mlflow.utils import (
+    _get_saved_model_path_for_tensorflow_and_keras_flavor,
+    _move_contents,
+)
+from sagemaker.serve.detector.dependency_manager import capture_dependencies
+from sagemaker.serve.validations.check_integrity import (
+    generate_secret_key,
+    compute_hash,
+)
+from sagemaker.remote_function.core.serialization import _MetaData
+
+
+def prepare_for_tf_serving(
+    model_path: str,
+    shared_libs: List[str],
+    dependencies: Dict[str, Any],
+) -> str:
+    """Prepares the model for serving.
+
+    Args:
+        model_path (str): Path to the model directory.
+        shared_libs (List[str]): List of shared libraries.
+        dependencies (Dict[str, Any]): Dictionary of dependencies.
+
+    Returns:
+        str: Secret key.
+    """
+
+    _model_path = Path(model_path)
+    if not _model_path.exists():
+        _model_path.mkdir()
+    elif not _model_path.is_dir():
+        raise Exception("model_dir is not a valid directory")
+
+    code_dir = _model_path.joinpath("code")
+    code_dir.mkdir(exist_ok=True)
+    shutil.copy2(Path(__file__).parent.joinpath("inference.py"), code_dir)
+
+    shared_libs_dir = _model_path.joinpath("shared_libs")
+    shared_libs_dir.mkdir(exist_ok=True)
+    for shared_lib in shared_libs:
+        shutil.copy2(Path(shared_lib), shared_libs_dir)
+
+    capture_dependencies(dependencies=dependencies, work_dir=code_dir)
+
+    saved_model_bundle_dir = _model_path.joinpath("1")
+    saved_model_bundle_dir.mkdir(exist_ok=True)
+    mlflow_saved_model_dir = _get_saved_model_path_for_tensorflow_and_keras_flavor(model_path)
+    if not mlflow_saved_model_dir:
+        raise ValueError("SavedModel is not found for Tensorflow or Keras flavor.")
+    _move_contents(src_dir=mlflow_saved_model_dir, dest_dir=saved_model_bundle_dir)
+
+    secret_key = generate_secret_key()
+    with open(str(code_dir.joinpath("serve.pkl")), "rb") as f:
+        buffer = f.read()
+    hash_value = compute_hash(buffer=buffer, secret_key=secret_key)
+    with open(str(code_dir.joinpath("metadata.json")), "wb") as metadata:
+        metadata.write(_MetaData(hash_value).to_json())
+
+    return secret_key
diff --git a/src/sagemaker/serve/model_server/tensorflow_serving/server.py b/src/sagemaker/serve/model_server/tensorflow_serving/server.py
new file mode 100644
index 0000000000..2392287c61
--- /dev/null
+++ b/src/sagemaker/serve/model_server/tensorflow_serving/server.py
@@ -0,0 +1,139 @@
+"""Module for Local Tensorflow Server"""
+
+from __future__ import absolute_import
+
+import requests
+import logging
+import platform
+from pathlib import Path
+from sagemaker.base_predictor import PredictorBase
+from sagemaker.session import Session
+from sagemaker.serve.utils.exceptions import LocalModelInvocationException
+from sagemaker.s3_utils import determine_bucket_and_prefix, parse_s3_url
+from sagemaker import fw_utils
+from sagemaker.serve.utils.uploader import upload
+from sagemaker.local.utils import get_docker_host
+
+logger = logging.getLogger(__name__)
+
+
+class LocalTensorflowServing:
+    """LocalTensorflowServing class."""
+
+    def _start_tensorflow_serving(
+        self, client: object, image: str, model_path: str, secret_key: str, env_vars: dict
+    ):
+        """Starts a local tensorflow serving container.
+
+        Args:
+            client: Docker client
+            image: Image to use
+            model_path: Path to the model
+            secret_key: Secret key to use for authentication
+            env_vars: Environment variables to set
+        """
+        self.container = client.containers.run(
+            image,
+            "serve",
+            detach=True,
+            auto_remove=True,
+            network_mode="host",
+            volumes={
+                Path(model_path): {
+                    "bind": "/opt/ml/model",
+                    "mode": "rw",
+                },
+            },
+            environment={
+                "SAGEMAKER_SUBMIT_DIRECTORY": "/opt/ml/model/code",
+                "SAGEMAKER_PROGRAM": "inference.py",
+                "SAGEMAKER_SERVE_SECRET_KEY": secret_key,
+                "LOCAL_PYTHON": platform.python_version(),
+                **env_vars,
+            },
+        )
+
+    def _invoke_tensorflow_serving(self, request: object, content_type: str, accept: str):
+        """Invokes a local tensorflow serving container.
+
+        Args:
+            request: Request to send
+            content_type: Content type to use
+            accept: Accept to use
+        """
+        try:
+            response = requests.post(
+                f"http://{get_docker_host()}:8080/invocations",
+                data=request,
+                headers={"Content-Type": content_type, "Accept": accept},
+                timeout=60,  # this is what SageMaker Hosting uses as timeout
+            )
+            response.raise_for_status()
+            return response.content
+        except Exception as e:
+            raise Exception("Unable to send request to the local container server") from e
+
+    def _tensorflow_serving_deep_ping(self, predictor: PredictorBase):
+        """Checks if the local tensorflow serving container is up and running.
+
+        If the container is not up and running, it will raise an exception.
+        """
+        response = None
+        try:
+            response = predictor.predict(self.schema_builder.sample_input)
+            return (True, response)
+        # pylint: disable=broad-except
+        except Exception as e:
+            if "422 Client Error: Unprocessable Entity for url" in str(e):
+                raise LocalModelInvocationException(str(e))
+            return (False, response)
+
+
+class SageMakerTensorflowServing:
+    """SageMakerTensorflowServing class."""
+
+    def _upload_tensorflow_serving_artifacts(
+        self,
+        model_path: str,
+        sagemaker_session: Session,
+        secret_key: str,
+        s3_model_data_url: str = None,
+        image: str = None,
+    ):
+        """Uploads the model artifacts to S3.
+
+        Args:
+            model_path: Path to the model
+            sagemaker_session: SageMaker session
+            secret_key: Secret key to use for authentication
+            s3_model_data_url: S3 model data URL
+            image: Image to use
+        """
+        if s3_model_data_url:
+            bucket, key_prefix = parse_s3_url(url=s3_model_data_url)
+        else:
+            bucket, key_prefix = None, None
+
+        code_key_prefix = fw_utils.model_code_key_prefix(key_prefix, None, image)
+
+        bucket, code_key_prefix = determine_bucket_and_prefix(
+            bucket=bucket, key_prefix=code_key_prefix, sagemaker_session=sagemaker_session
+        )
+
+        logger.debug(
+            "Uploading the model resources to bucket=%s, key_prefix=%s.", bucket, code_key_prefix
+        )
+        s3_upload_path = upload(sagemaker_session, model_path, bucket, code_key_prefix)
+        logger.debug("Model resources uploaded to: %s", s3_upload_path)
+
+        env_vars = {
+            "SAGEMAKER_SUBMIT_DIRECTORY": "/opt/ml/model/code",
+            "SAGEMAKER_PROGRAM": "inference.py",
+            "SAGEMAKER_REGION": sagemaker_session.boto_region_name,
+            "SAGEMAKER_CONTAINER_LOG_LEVEL": "10",
+            "SAGEMAKER_SERVE_SECRET_KEY": secret_key,
+            "LOCAL_PYTHON": platform.python_version(),
+        }
+        return s3_upload_path, env_vars
diff --git a/src/sagemaker/serve/utils/lineage_constants.py b/src/sagemaker/serve/utils/lineage_constants.py
new file mode 100644
index 0000000000..51be20739f
--- /dev/null
+++ b/src/sagemaker/serve/utils/lineage_constants.py
@@ -0,0 +1,28 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+"""Holds constants used for lineage support"""
+from __future__ import absolute_import
+
+
+LINEAGE_POLLER_INTERVAL_SECS = 15
+LINEAGE_POLLER_MAX_TIMEOUT_SECS = 120
+MODEL_BUILDER_MLFLOW_MODEL_PATH_LINEAGE_ARTIFACT_TYPE = "ModelBuilderInputModelData"
+MLFLOW_S3_PATH = "S3"
+MLFLOW_MODEL_PACKAGE_PATH = "ModelPackage"
+MLFLOW_RUN_ID = "MLflowRunId"
+MLFLOW_LOCAL_PATH = "Local"
+MLFLOW_REGISTRY_PATH = "MLflowRegistry"
+ERROR = "Error"
+CODE = "Code"
+CONTRIBUTED_TO = "ContributedTo"
+VALIDATION_EXCEPTION = "ValidationException"
diff --git a/src/sagemaker/serve/utils/lineage_utils.py b/src/sagemaker/serve/utils/lineage_utils.py
new file mode 100644
index 0000000000..b2e28d26c3
--- /dev/null
+++ b/src/sagemaker/serve/utils/lineage_utils.py
@@ -0,0 +1,277 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+"""Holds the util functions used for lineage tracking"""
+from __future__ import absolute_import
+
+import os
+import time
+import re
+import logging
+from typing import Optional, Union
+
+from botocore.exceptions import ClientError
+
+from sagemaker import Session
+from sagemaker.lineage._api_types import ArtifactSummary
+from sagemaker.lineage.artifact import Artifact
+from sagemaker.lineage.association import Association
+from sagemaker.lineage.query import LineageSourceEnum
+from sagemaker.serve.model_format.mlflow.constants import (
+    MLFLOW_RUN_ID_REGEX,
+    MODEL_PACKAGE_ARN_REGEX,
+    S3_PATH_REGEX,
+    MLFLOW_REGISTRY_PATH_REGEX,
+)
+from sagemaker.serve.utils.lineage_constants import (
+    LINEAGE_POLLER_MAX_TIMEOUT_SECS,
+    LINEAGE_POLLER_INTERVAL_SECS,
+    MLFLOW_S3_PATH,
+    MODEL_BUILDER_MLFLOW_MODEL_PATH_LINEAGE_ARTIFACT_TYPE,
+    MLFLOW_LOCAL_PATH,
+    MLFLOW_MODEL_PACKAGE_PATH,
+    MLFLOW_RUN_ID,
+    MLFLOW_REGISTRY_PATH,
+    CONTRIBUTED_TO,
+    ERROR,
+    CODE,
+    VALIDATION_EXCEPTION,
+)
+
+logger = logging.getLogger(__name__)
+
+
+def _load_artifact_by_source_uri(
+    source_uri: str, artifact_type: str, sagemaker_session: Session
+) -> Optional[ArtifactSummary]:
+    """Load lineage artifact by source uri
+
+    Arguments:
+        source_uri (str): The s3 uri used for uploading transformed model artifacts.
+        artifact_type (str): The type of the lineage artifact.
+        sagemaker_session (Session): Session object which manages interactions
+            with Amazon SageMaker APIs and any other AWS services needed. If not specified, the
+            function creates one using the default AWS configuration chain.
+
+    Returns:
+        ArtifactSummary: The Artifact Summary for the provided S3 URI.
+    """
+    artifacts = Artifact.list(source_uri=source_uri, sagemaker_session=sagemaker_session)
+    for artifact_summary in artifacts:
+        if artifact_summary.artifact_type == artifact_type:
+            return artifact_summary
+    return None
+
+
+def _poll_lineage_artifact(
+    s3_uri: str, artifact_type: str, sagemaker_session: Session
+) -> Optional[ArtifactSummary]:
+    """Polls lineage artifacts by s3 path.
+
+    Arguments:
+        s3_uri (str): The S3 URI to check for artifacts.
+        artifact_type (str): The type of the lineage artifact.
+        sagemaker_session (Session): Session object which manages interactions
+            with Amazon SageMaker APIs and any other AWS services needed. If not specified, the
+            function creates one using the default AWS configuration chain.
+
+    Returns:
+        Optional[ArtifactSummary]: The artifact summary if found, otherwise None.
+    """
+    logger.info("Polling lineage artifact for model data in %s", s3_uri)
+    start_time = time.time()
+    while time.time() - start_time < LINEAGE_POLLER_MAX_TIMEOUT_SECS:
+        result = _load_artifact_by_source_uri(s3_uri, artifact_type, sagemaker_session)
+        if result is not None:
+            return result
+        time.sleep(LINEAGE_POLLER_INTERVAL_SECS)
+
+
+def _get_mlflow_model_path_type(mlflow_model_path: str) -> str:
+    """Identify mlflow model path type.
+
+    Args:
+        mlflow_model_path (str): The string to be identified.
+
+    Returns:
+        str: Description of what the input string is identified as.
+    """
+    mlflow_run_id_pattern = MLFLOW_RUN_ID_REGEX
+    mlflow_registry_id_pattern = MLFLOW_REGISTRY_PATH_REGEX
+    sagemaker_arn_pattern = MODEL_PACKAGE_ARN_REGEX
+    s3_pattern = S3_PATH_REGEX
+
+    if re.match(mlflow_run_id_pattern, mlflow_model_path):
+        return MLFLOW_RUN_ID
+    if re.match(mlflow_registry_id_pattern, mlflow_model_path):
+        return MLFLOW_REGISTRY_PATH
+    if re.match(sagemaker_arn_pattern, mlflow_model_path):
+        return MLFLOW_MODEL_PACKAGE_PATH
+    if re.match(s3_pattern, mlflow_model_path):
+        return MLFLOW_S3_PATH
+    if os.path.exists(mlflow_model_path):
+        return MLFLOW_LOCAL_PATH
+
+    raise ValueError(f"Invalid MLflow model path: {mlflow_model_path}")
+
+
+def _create_mlflow_model_path_lineage_artifact(
+    mlflow_model_path: str,
+    sagemaker_session: Session,
+) -> Optional[Artifact]:
+    """Creates a lineage artifact for the given MLflow model path.
+
+    Args:
+        mlflow_model_path (str): The path to the MLflow model.
+        sagemaker_session (Session): The SageMaker session object.
+
+    Returns:
+        Optional[Artifact]: The created lineage artifact, or None if an error occurred.
+    """
+    _artifact_name = _get_mlflow_model_path_type(mlflow_model_path)
+    properties = dict(
+        model_builder_input_model_data_type=_artifact_name,
+    )
+    try:
+        return Artifact.create(
+            source_uri=mlflow_model_path,
+            artifact_type=MODEL_BUILDER_MLFLOW_MODEL_PATH_LINEAGE_ARTIFACT_TYPE,
+            artifact_name=_artifact_name,
+            properties=properties,
+            sagemaker_session=sagemaker_session,
+        )
+    except ClientError as e:
+        if e.response[ERROR][CODE] == VALIDATION_EXCEPTION:
+            logger.info("Artifact already exists")
+        else:
+            logger.warning("Failed to create mlflow model path lineage artifact: %s", e)
+            raise e
+
+
+def _retrieve_and_create_if_not_exist_mlflow_model_path_lineage_artifact(
+    mlflow_model_path: str,
+    sagemaker_session: Session,
+) -> Optional[Union[Artifact, ArtifactSummary]]:
+    """Retrieves an existing artifact for the given MLflow model path or
+
+    creates a new one if it doesn't exist.
+
+    Args:
+        mlflow_model_path (str): The path to the MLflow model.
+        sagemaker_session (Session): Session object which manages interactions
+            with Amazon SageMaker APIs and any other AWS services needed. If not specified, the
+            function creates one using the default AWS configuration chain.
+
+    Returns:
+        Optional[Union[Artifact, ArtifactSummary]]: The existing or newly created artifact,
+            or None if an error occurred.
+    """
+    _loaded_artifact = _load_artifact_by_source_uri(
+        mlflow_model_path, MODEL_BUILDER_MLFLOW_MODEL_PATH_LINEAGE_ARTIFACT_TYPE, sagemaker_session
+    )
+    if _loaded_artifact is not None:
+        return _loaded_artifact
+    return _create_mlflow_model_path_lineage_artifact(
+        mlflow_model_path,
+        sagemaker_session,
+    )
+
+
+def _add_association_between_artifacts(
+    mlflow_model_path_artifact_arn: str,
+    autogenerated_model_data_artifact_arn: str,
+    sagemaker_session: Session,
+) -> None:
+    """Add association between mlflow model path artifact and autogenerated model data artifact.
+
+    Arguments:
+        mlflow_model_path_artifact_arn (str): The mlflow model path artifact.
+        autogenerated_model_data_artifact_arn (str): The autogenerated model data artifact.
+        sagemaker_session (Session): Session object which manages interactions
+            with Amazon SageMaker APIs and any other AWS services needed. If not specified, the
+            function creates one using the default AWS configuration chain.
+    """
+    _association_type = CONTRIBUTED_TO
+    _source_arn = mlflow_model_path_artifact_arn
+    _destination_arn = autogenerated_model_data_artifact_arn
+    try:
+        logger.info(
+            "Adding association with source_arn: "
+            "%s, destination_arn: %s and association_type: %s.",
+            _source_arn,
+            _destination_arn,
+            _association_type,
+        )
+        Association.create(
+            source_arn=_source_arn,
+            destination_arn=_destination_arn,
+            association_type=_association_type,
+            sagemaker_session=sagemaker_session,
+        )
+    except ClientError as e:
+        if e.response[ERROR][CODE] == VALIDATION_EXCEPTION:
+            logger.info("Association already exists")
+        else:
+            raise e
+
+
+def _maintain_lineage_tracking_for_mlflow_model(
+    mlflow_model_path: str,
+    s3_upload_path: str,
+    sagemaker_session: Session,
+) -> None:
+    """Maintains lineage tracking for an MLflow model by creating or retrieving artifacts.
+
+    Args:
+        mlflow_model_path (str): The path to the MLflow model.
+        s3_upload_path (str): The S3 path where the transformed model data is uploaded.
+        sagemaker_session (Session): Session object which manages interactions
+            with Amazon SageMaker APIs and any other AWS services needed. If not specified, the
+            function creates one using the default AWS configuration chain.
+    """
+    artifact_for_transformed_model_data = _poll_lineage_artifact(
+        s3_uri=s3_upload_path,
+        artifact_type=LineageSourceEnum.MODEL_DATA.value,
+        sagemaker_session=sagemaker_session,
+    )
+    if artifact_for_transformed_model_data:
+        mlflow_model_artifact = (
+            _retrieve_and_create_if_not_exist_mlflow_model_path_lineage_artifact(
+                mlflow_model_path=mlflow_model_path,
+                sagemaker_session=sagemaker_session,
+            )
+        )
+        if mlflow_model_artifact:
+            _mlflow_model_artifact_arn = (
+                mlflow_model_artifact.artifact_arn
+            )  # pylint: disable=E1101, disable=C0301
+            _artifact_for_transformed_model_data_arn = (
+                artifact_for_transformed_model_data.artifact_arn
+            )  # pylint: disable=C0301
+            _add_association_between_artifacts(
+                mlflow_model_path_artifact_arn=_mlflow_model_artifact_arn,
+                autogenerated_model_data_artifact_arn=_artifact_for_transformed_model_data_arn,
+                sagemaker_session=sagemaker_session,
+            )
+        else:
+            logger.warning(
+                "Unable to add association between autogenerated lineage "
+                "artifact for transformed model data and mlflow model path"
+                " lineage artifacts."
+            )
+    else:
+        logger.warning(
+            "Lineage artifact for transformed model data is not auto-created within "
+            "%s seconds, skipping creation of lineage artifact for mlflow model path",
+            LINEAGE_POLLER_MAX_TIMEOUT_SECS,
+        )
diff --git a/src/sagemaker/serve/utils/predictors.py b/src/sagemaker/serve/utils/predictors.py
index e0ff8f8ee1..866167c2c6 100644
--- a/src/sagemaker/serve/utils/predictors.py
+++ b/src/sagemaker/serve/utils/predictors.py
@@ -209,6 +209,47 @@ def delete_predictor(self):
         self._mode_obj.destroy_server()


+class TensorflowServingLocalPredictor(PredictorBase):
+    """Lightweight predictor for local deployment in LOCAL_CONTAINER modes"""
+
+    # TODO: change mode_obj to union of IN_PROCESS and LOCAL_CONTAINER objs
+    def __init__(
+        self,
+        mode_obj: Type[LocalContainerMode],
+        serializer=IdentitySerializer(),
+        deserializer=BytesDeserializer(),
+    ):
+        self._mode_obj = mode_obj
+        self.serializer = serializer
+        self.deserializer = deserializer
+
+    def predict(self, data):
+        """Placeholder docstring"""
+        return self.deserializer.deserialize(
+            io.BytesIO(
+                self._mode_obj._invoke_tensorflow_serving(
+                    self.serializer.serialize(data),
+                    self.content_type,
+                    self.accept[0],
+                )
+            )
+        )
+
+    @property
+    def content_type(self):
+        """The MIME type of the data sent to the inference endpoint."""
+        return self.serializer.CONTENT_TYPE
+
+    @property
+    def accept(self):
+        """The content type(s) that are expected from the inference endpoint."""
+        return self.deserializer.ACCEPT
+
+    def delete_predictor(self):
+        """Shut down and remove the container that you created in LOCAL_CONTAINER mode"""
+        self._mode_obj.destroy_server()
+
+
 def _get_local_mode_predictor(
     model_server: ModelServer,
     mode_obj: Type[LocalContainerMode],
@@ -223,6 +264,11 @@ def _get_local_mode_predictor(
     if model_server == ModelServer.TRITON:
         return TritonLocalPredictor(mode_obj=mode_obj)

+    if model_server == ModelServer.TENSORFLOW_SERVING:
+        return TensorflowServingLocalPredictor(
+            mode_obj=mode_obj, serializer=serializer, deserializer=deserializer
+        )
+
     raise ValueError("%s model server is not supported yet!" % model_server)
diff --git a/tests/data/serve_resources/mlflow/tensorflow/MLmodel b/tests/data/serve_resources/mlflow/tensorflow/MLmodel
new file mode 100644
index 0000000000..f00412149d
--- /dev/null
+++ b/tests/data/serve_resources/mlflow/tensorflow/MLmodel
@@ -0,0 +1,17 @@
+artifact_path: model
+flavors:
+  python_function:
+    env:
+      conda: conda.yaml
+      virtualenv: python_env.yaml
+    loader_module: mlflow.tensorflow
+    python_version: 3.10.13
+  tensorflow:
+    code: null
+    model_type: tf2-module
+    saved_model_dir: tf2model
+mlflow_version: 2.11.1
+model_size_bytes: 23823
+model_uuid: 40d2323944294fce898d8693455f60e8
+run_id: 592132312fb84935b201de2c027c54c6
+utc_time_created: '2024-04-01 19:47:15.396517'
diff --git a/tests/data/serve_resources/mlflow/tensorflow/conda.yaml b/tests/data/serve_resources/mlflow/tensorflow/conda.yaml
new file mode 100644
index 0000000000..90d8c300a0
--- /dev/null
+++ b/tests/data/serve_resources/mlflow/tensorflow/conda.yaml
@@ -0,0 +1,11 @@
+channels:
+- conda-forge
+dependencies:
+- python=3.10.13
+- pip<=23.3.1
+- pip:
+  - mlflow==2.11.1
+  - cloudpickle==2.2.1
+  - numpy==1.26.4
+  - tensorflow==2.16.1
+name: mlflow-env
diff --git a/tests/data/serve_resources/mlflow/tensorflow/python_env.yaml b/tests/data/serve_resources/mlflow/tensorflow/python_env.yaml
new file mode 100644
index 0000000000..9e09178b6c
--- /dev/null
+++ b/tests/data/serve_resources/mlflow/tensorflow/python_env.yaml
@@ -0,0 +1,7 @@
+python: 3.10.13
+build_dependencies:
+- pip==23.3.1
+- setuptools==68.2.2
+- wheel==0.41.2
+dependencies:
+- -r requirements.txt
diff --git a/tests/data/serve_resources/mlflow/tensorflow/registered_model_meta b/tests/data/serve_resources/mlflow/tensorflow/registered_model_meta
new file mode 100644
index 0000000000..5423c0e6c7
--- /dev/null
+++ b/tests/data/serve_resources/mlflow/tensorflow/registered_model_meta
@@ -0,0 +1,2 @@
+model_name: model
+model_version: '2'
diff --git a/tests/data/serve_resources/mlflow/tensorflow/requirements.txt b/tests/data/serve_resources/mlflow/tensorflow/requirements.txt
new file mode 100644
index 0000000000..2ff55b8e87
--- /dev/null
+++ b/tests/data/serve_resources/mlflow/tensorflow/requirements.txt
@@ -0,0 +1,4 @@
+mlflow==2.11.1
+cloudpickle==2.2.1
+numpy==1.26.4
+tensorflow==2.16.1
diff --git a/tests/data/serve_resources/mlflow/tensorflow/tf2model/fingerprint.pb b/tests/data/serve_resources/mlflow/tensorflow/tf2model/fingerprint.pb
new file mode 100644
index 0000000000..ba1e240ba5
--- /dev/null
+++ b/tests/data/serve_resources/mlflow/tensorflow/tf2model/fingerprint.pb
@@ -0,0 +1 @@
+´™ÊÁ´¼ÜÌÌĽð˜Œëón”™¯’/ öÊ¢„ÑòÁ„‰(Œ½£‘Èð32
\ No newline at end of file
diff --git a/tests/data/serve_resources/mlflow/tensorflow/tf2model/saved_model.pb b/tests/data/serve_resources/mlflow/tensorflow/tf2model/saved_model.pb
new file mode 100644
index 0000000000..e48f2b59cc
Binary files /dev/null and b/tests/data/serve_resources/mlflow/tensorflow/tf2model/saved_model.pb differ
diff --git a/tests/data/serve_resources/mlflow/tensorflow/tf2model/variables/variables.data-00000-of-00001 b/tests/data/serve_resources/mlflow/tensorflow/tf2model/variables/variables.data-00000-of-00001
new file mode 100644
index 0000000000..575da96282
Binary files /dev/null and b/tests/data/serve_resources/mlflow/tensorflow/tf2model/variables/variables.data-00000-of-00001 differ
diff --git a/tests/data/serve_resources/mlflow/tensorflow/tf2model/variables/variables.index b/tests/data/serve_resources/mlflow/tensorflow/tf2model/variables/variables.index
new file mode 100644
index 0000000000..57646ac350
Binary files /dev/null and b/tests/data/serve_resources/mlflow/tensorflow/tf2model/variables/variables.index differ
diff --git a/tests/integ/sagemaker/serve/constants.py b/tests/integ/sagemaker/serve/constants.py
index 794f7333a3..d5e7a56f83 100644
--- a/tests/integ/sagemaker/serve/constants.py
+++ b/tests/integ/sagemaker/serve/constants.py
@@ -32,6 +32,7 @@
     DATA_DIR, "serve_resources", "mlflow", "pytorch"
 )
 XGBOOST_MLFLOW_RESOURCE_DIR = os.path.join(DATA_DIR, "serve_resources", "mlflow", "xgboost")
+TENSORFLOW_MLFLOW_RESOURCE_DIR = os.path.join(DATA_DIR, "serve_resources", "mlflow", "tensorflow")
 TF_EFFICIENT_RESOURCE_DIR = os.path.join(DATA_DIR, "serve_resources", "tensorflow")
 HF_DIR = os.path.join(DATA_DIR, "serve_resources", "hf")

diff --git a/tests/integ/sagemaker/serve/test_serve_mlflow_pytorch_flavor_happy.py b/tests/integ/sagemaker/serve/test_serve_mlflow_pytorch_flavor_happy.py
index e7ebd9c5bf..e6beb76d6e 100644
--- a/tests/integ/sagemaker/serve/test_serve_mlflow_pytorch_flavor_happy.py
+++ b/tests/integ/sagemaker/serve/test_serve_mlflow_pytorch_flavor_happy.py
@@ -19,6 +19,8 @@
 import io
 import numpy as np

+from sagemaker.lineage.artifact import Artifact
+from sagemaker.lineage.association import Association
 from sagemaker.s3 import S3Uploader
 from sagemaker.serve.builder.model_builder import ModelBuilder, Mode
 from sagemaker.serve.builder.schema_builder import SchemaBuilder, CustomPayloadTranslator
@@ -35,6 +37,10 @@
 from tests.integ.utils import cleanup_model_resources
 import logging

+from sagemaker.serve.utils.lineage_constants import (
+    MODEL_BUILDER_MLFLOW_MODEL_PATH_LINEAGE_ARTIFACT_TYPE,
+)
+
 logger = logging.getLogger(__name__)

 ROLE_NAME = "SageMakerRole"
@@ -205,6 +211,19 @@ def test_happy_pytorch_sagemaker_endpoint_with_torch_serve(
             predictor = model.deploy(instance_type=cpu_instance_type, initial_instance_count=1)
             logger.info("Endpoint successfully deployed.")
             predictor.predict(test_image)
+            model_data_artifact = None
+            for artifact in Artifact.list(
+                source_uri=model_builder.s3_upload_path, sagemaker_session=sagemaker_session
+            ):
+                model_data_artifact = artifact
+            for association in Association.list(
+                destination_arn=model_data_artifact.artifact_arn,
+                sagemaker_session=sagemaker_session,
+            ):
+                assert (
+                    association.source_type == MODEL_BUILDER_MLFLOW_MODEL_PATH_LINEAGE_ARTIFACT_TYPE
+                )
+                break
         except Exception as e:
             caught_ex = e
         finally:
@@ -214,9 +233,4 @@
                 endpoint_name=model.endpoint_name,
             )
             if caught_ex:
-                logger.exception(caught_ex)
-                ignore_if_worker_dies = "Worker died." in str(caught_ex)
-                # https://github.com/pytorch/serve/issues/3032
-                assert (
-                    ignore_if_worker_dies
-                ), f"{caught_ex} was thrown when running pytorch squeezenet sagemaker endpoint test"
+                raise caught_ex
diff --git a/tests/integ/sagemaker/serve/test_serve_mlflow_tensorflow_flavor_happy.py b/tests/integ/sagemaker/serve/test_serve_mlflow_tensorflow_flavor_happy.py
new file mode 100644
index 0000000000..c25cbd7e18
--- /dev/null
+++ b/tests/integ/sagemaker/serve/test_serve_mlflow_tensorflow_flavor_happy.py
@@ -0,0 +1,175 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+from __future__ import absolute_import
+
+import pytest
+import io
+import numpy as np
+
+from sagemaker.lineage.artifact import Artifact
+from sagemaker.lineage.association import Association
+from sagemaker.s3 import S3Uploader
+from sagemaker.serve.builder.model_builder import ModelBuilder, Mode
+from sagemaker.serve.builder.schema_builder import SchemaBuilder, CustomPayloadTranslator
+import tensorflow as tf
+from sklearn.datasets import fetch_california_housing
+
+
+from tests.integ.sagemaker.serve.constants import (
+    TENSORFLOW_MLFLOW_RESOURCE_DIR,
+    SERVE_SAGEMAKER_ENDPOINT_TIMEOUT,
+    PYTHON_VERSION_IS_NOT_310,
+)
+from tests.integ.timeout import timeout
+from tests.integ.utils import cleanup_model_resources
+import logging
+
+from sagemaker.serve.utils.lineage_constants import (
+    MODEL_BUILDER_MLFLOW_MODEL_PATH_LINEAGE_ARTIFACT_TYPE,
+)
+
+logger = logging.getLogger(__name__)
+
+ROLE_NAME = "SageMakerRole"
+
+
+@pytest.fixture
+def test_data():
+    dataset = fetch_california_housing(as_frame=True)["frame"]
+    dataset = dataset.dropna()
+    dataset_tf = tf.convert_to_tensor(dataset, dtype=tf.float32)
+    dataset_tf = dataset_tf[:50]
+    x_test, y_test = dataset_tf[:, :-1], dataset_tf[:, -1]
+    return x_test, y_test
+
+
+@pytest.fixture
+def custom_request_translator():
+    class MyRequestTranslator(CustomPayloadTranslator):
+        def serialize_payload_to_bytes(self, payload: object) -> bytes:
+            return self._convert_numpy_to_bytes(payload)
+
+        def deserialize_payload_from_stream(self, stream) -> object:
+            np_array = np.load(io.BytesIO(stream.read()))
+            return np_array
+
+        def _convert_numpy_to_bytes(self, np_array: np.ndarray) -> bytes:
+            buffer = io.BytesIO()
+            np.save(buffer, np_array)
+            return buffer.getvalue()
+
+    return MyRequestTranslator()
+
+
+@pytest.fixture
+def custom_response_translator():
+    class MyResponseTranslator(CustomPayloadTranslator):
+        def serialize_payload_to_bytes(self, payload: object) -> bytes:
+            import numpy as np
+
+            return self._convert_numpy_to_bytes(np.array(payload))
+
+        def deserialize_payload_from_stream(self, stream) -> object:
+            import tensorflow as tf
+
+            return tf.convert_to_tensor(np.load(io.BytesIO(stream.read())))
+
+        def _convert_numpy_to_bytes(self, np_array: np.ndarray) -> bytes:
+            buffer = io.BytesIO()
+            np.save(buffer, np_array)
+            return buffer.getvalue()
+
+    return MyResponseTranslator()
+
+
+@pytest.fixture
+def tensorflow_schema_builder(custom_request_translator, custom_response_translator, test_data):
+    input_data, output_data = test_data
+    return SchemaBuilder(
+        sample_input=input_data,
+        sample_output=output_data,
+        input_translator=custom_request_translator,
+        output_translator=custom_response_translator,
+    )
+
+
+@pytest.mark.skipif(
+    PYTHON_VERSION_IS_NOT_310,
+    reason="The goal of these tests is to test the serving components of our feature",
+)
+def test_happy_tensorflow_sagemaker_endpoint_with_tensorflow_serving(
+    sagemaker_session,
+    tensorflow_schema_builder,
+    cpu_instance_type,
+    test_data,
+):
+    logger.info("Running in SAGEMAKER_ENDPOINT mode...")
+    caught_ex = None
+
+    iam_client = sagemaker_session.boto_session.client("iam")
+    role_arn = iam_client.get_role(RoleName=ROLE_NAME)["Role"]["Arn"]
+
+    model_artifacts_uri = "s3://{}/{}/{}/{}".format(
+        sagemaker_session.default_bucket(),
+        "model_builder_integ_test",
+        "mlflow",
+        "tensorflow",
+    )
+
+    model_path = S3Uploader.upload(
+        local_path=TENSORFLOW_MLFLOW_RESOURCE_DIR,
+        desired_s3_uri=model_artifacts_uri,
+        sagemaker_session=sagemaker_session,
+    )
+
+    model_builder = ModelBuilder(
+        mode=Mode.SAGEMAKER_ENDPOINT,
+        schema_builder=tensorflow_schema_builder,
+        role_arn=role_arn,
+        sagemaker_session=sagemaker_session,
+        model_metadata={"MLFLOW_MODEL_PATH": model_path},
+    )
+
+    model = model_builder.build(sagemaker_session=sagemaker_session)
+
+    with timeout(minutes=SERVE_SAGEMAKER_ENDPOINT_TIMEOUT):
+        try:
+            test_x, _ = test_data
+            logger.info("Deploying and predicting in SAGEMAKER_ENDPOINT mode...")
+            predictor = model.deploy(instance_type=cpu_instance_type, initial_instance_count=1)
+            logger.info("Endpoint successfully deployed.")
+            predictor.predict(test_x)
+            model_data_artifact = None
+            for artifact in Artifact.list(
+                source_uri=model_builder.s3_upload_path, sagemaker_session=sagemaker_session
+            ):
+                model_data_artifact = artifact
+            for association in Association.list(
+                destination_arn=model_data_artifact.artifact_arn,
+                sagemaker_session=sagemaker_session,
+            ):
+                assert (
+                    association.source_type == MODEL_BUILDER_MLFLOW_MODEL_PATH_LINEAGE_ARTIFACT_TYPE
+                )
+                break
+
+        except Exception as e:
+            caught_ex = e
+        finally:
+            cleanup_model_resources(
+                sagemaker_session=model_builder.sagemaker_session,
+                model_name=model.name,
+                endpoint_name=model.endpoint_name,
+            )
+            if caught_ex:
+                raise caught_ex
diff --git a/tests/integ/sagemaker/serve/test_serve_mlflow_xgboost_flavor_happy.py b/tests/integ/sagemaker/serve/test_serve_mlflow_xgboost_flavor_happy.py
index 5a73942afe..7b47440a97 100644
--- a/tests/integ/sagemaker/serve/test_serve_mlflow_xgboost_flavor_happy.py
+++ b/tests/integ/sagemaker/serve/test_serve_mlflow_xgboost_flavor_happy.py
@@ -16,6 +16,8 @@
 import io
 import numpy as np

+from sagemaker.lineage.artifact import Artifact
+from sagemaker.lineage.association import Association
 from sagemaker.s3 import S3Uploader
 from sagemaker.serve.builder.model_builder import ModelBuilder, Mode
 from sagemaker.serve.builder.schema_builder import SchemaBuilder, CustomPayloadTranslator
@@ -32,6 +34,10 @@
 from tests.integ.utils import cleanup_model_resources
 import logging

+from sagemaker.serve.utils.lineage_constants import (
+    MODEL_BUILDER_MLFLOW_MODEL_PATH_LINEAGE_ARTIFACT_TYPE,
+)
+
 logger = logging.getLogger(__name__)

 ROLE_NAME = "SageMakerRole"
@@ -187,6 +193,19 @@ def test_happy_xgboost_sagemaker_endpoint_with_torch_serve(
             predictor = model.deploy(instance_type=cpu_instance_type, initial_instance_count=1)
             logger.info("Endpoint successfully deployed.")
             predictor.predict(test_x)
+            model_data_artifact = None
+            for artifact in Artifact.list(
+                source_uri=model_builder.s3_upload_path, sagemaker_session=sagemaker_session
+            ):
+                model_data_artifact = artifact
+            for association in Association.list(
+                destination_arn=model_data_artifact.artifact_arn,
+                sagemaker_session=sagemaker_session,
+            ):
+                assert (
+                    association.source_type == MODEL_BUILDER_MLFLOW_MODEL_PATH_LINEAGE_ARTIFACT_TYPE
+                )
+                break
         except Exception as e:
             caught_ex = e
         finally:
@@ -196,9 +215,4 @@
                 endpoint_name=model.endpoint_name,
             )
             if caught_ex:
-                logger.exception(caught_ex)
-                ignore_if_worker_dies = "Worker died." in str(caught_ex)
-                # https://github.com/pytorch/serve/issues/3032
-                assert (
-                    ignore_if_worker_dies
-                ), f"{caught_ex} was thrown when running pytorch squeezenet sagemaker endpoint test"
+                raise caught_ex
diff --git a/tests/unit/sagemaker/serve/builder/test_model_builder.py b/tests/unit/sagemaker/serve/builder/test_model_builder.py
index 1d199b7401..3ffbdd7c03 100644
--- a/tests/unit/sagemaker/serve/builder/test_model_builder.py
+++ b/tests/unit/sagemaker/serve/builder/test_model_builder.py
@@ -21,6 +21,7 @@
 from sagemaker.serve.mode.function_pointers import Mode
 from sagemaker.serve.utils import task
 from sagemaker.serve.utils.exceptions import TaskNotFoundException
+from sagemaker.serve.utils.predictors import TensorflowServingLocalPredictor
 from sagemaker.serve.utils.types import ModelServer
 from tests.unit.sagemaker.serve.constants import MOCK_IMAGE_CONFIG, MOCK_VPC_CONFIG

@@ -52,6 +53,7 @@
     ModelServer.TORCHSERVE,
     ModelServer.TRITON,
     ModelServer.DJL_SERVING,
+    ModelServer.TENSORFLOW_SERVING,
 }

 mock_session = MagicMock()
@@ -1677,6 +1679,7 @@ def test_build_task_override_with_invalid_model_provided(
         model_builder.build(sagemaker_session=mock_session)

     @patch("os.makedirs", Mock())
+    @patch("sagemaker.serve.builder.model_builder._maintain_lineage_tracking_for_mlflow_model")
     @patch("sagemaker.serve.builder.model_builder._detect_framework_and_version")
     @patch("sagemaker.serve.builder.model_builder.prepare_for_torchserve")
     @patch("sagemaker.serve.builder.model_builder.save_pkl")
@@ -1705,6 +1708,7 @@ def test_build_mlflow_model_local_input_happy(
         mock_save_pkl,
         mock_prepare_for_torchserve,
         mock_detect_fw_version,
+        mock_lineage_tracking,
     ):
         # setup mocks

@@ -1750,6 +1754,85 @@ def test_build_mlflow_model_local_input_happy(
         self.assertEqual(build_result.serve_settings, mock_setting_object)
         self.assertEqual(builder.env_vars["MLFLOW_MODEL_FLAVOR"], "sklearn")

+        build_result.deploy(
+            initial_instance_count=1, instance_type=mock_instance_type, mode=Mode.SAGEMAKER_ENDPOINT
+        )
+        mock_lineage_tracking.assert_called_once()
+
+    @patch("os.makedirs", Mock())
+    @patch("sagemaker.serve.builder.model_builder._detect_framework_and_version")
+    @patch("sagemaker.serve.builder.model_builder.prepare_for_torchserve")
+    @patch("sagemaker.serve.builder.model_builder.save_pkl")
+    @patch("sagemaker.serve.builder.model_builder._copy_directory_contents")
+    @patch("sagemaker.serve.builder.model_builder._generate_mlflow_artifact_path")
+    @patch("sagemaker.serve.builder.model_builder._get_all_flavor_metadata")
+    @patch("sagemaker.serve.builder.model_builder._select_container_for_mlflow_model")
+    @patch("sagemaker.serve.builder.model_builder._ServeSettings")
+    @patch("sagemaker.serve.builder.model_builder.SageMakerEndpointMode")
+    @patch("sagemaker.serve.builder.model_builder.Model")
+    @patch("builtins.open", new_callable=mock_open, read_data="data")
+    @patch("os.path.isfile", return_value=True)
+    @patch("os.path.exists")
+    def test_build_mlflow_model_local_input_happy_flavor_server_mismatch(
+        self,
+        mock_path_exists,
+        mock_is_file,
+        mock_open,
+        mock_sdk_model,
+        mock_sageMakerEndpointMode,
+        mock_serveSettings,
+        mock_detect_container,
+        mock_get_all_flavor_metadata,
+        mock_generate_mlflow_artifact_path,
+        mock_copy_directory_contents,
+        mock_save_pkl,
+        mock_prepare_for_torchserve,
+        mock_detect_fw_version,
+    ):
+        # setup mocks
+
+        mock_detect_container.return_value = mock_image_uri
+        mock_get_all_flavor_metadata.return_value = {"sklearn": "some_data"}
+        mock_generate_mlflow_artifact_path.return_value = "some_path"
+
+        mock_prepare_for_torchserve.return_value = mock_secret_key
+
+        # Mock _ServeSettings
+        mock_setting_object = mock_serveSettings.return_value
+        mock_setting_object.role_arn = mock_role_arn
+        mock_setting_object.s3_model_data_url = mock_s3_model_data_url
+
+        mock_path_exists.side_effect = lambda path: True if path == "test_path" else False
+
+        mock_mode = Mock()
+        mock_sageMakerEndpointMode.side_effect = lambda inference_spec, model_server: (
+            mock_mode if inference_spec is None and model_server == ModelServer.TORCHSERVE else None
+        )
+        mock_mode.prepare.return_value = (
+            model_data,
+            ENV_VAR_PAIR,
+        )
+
+        updated_env_var = deepcopy(ENV_VARS)
+        updated_env_var.update({"MLFLOW_MODEL_FLAVOR": "sklearn"})
+        mock_model_obj = Mock()
+        mock_sdk_model.return_value = mock_model_obj
+
+        mock_session.sagemaker_client._user_agent_creator.to_string = lambda: "sample agent"
+
+        # run
+        builder = ModelBuilder(
+            schema_builder=schema_builder,
+            model_metadata={"MLFLOW_MODEL_PATH": MODEL_PATH},
+            model_server=ModelServer.TENSORFLOW_SERVING,
+        )
+        with self.assertRaises(ValueError):
+            builder.build(
+                Mode.SAGEMAKER_ENDPOINT,
+                mock_role_arn,
+                mock_session,
+            )
+
     @patch("os.makedirs", Mock())
     @patch("sagemaker.serve.builder.model_builder.S3Downloader.list")
     @patch("sagemaker.serve.builder.model_builder._detect_framework_and_version")
@@ -1899,3 +1982,240 @@ def test_build_mlflow_model_s3_input_non_mlflow_case(
             mock_role_arn,
             mock_session,
         )
+
+    @patch("os.makedirs", Mock())
+    @patch("sagemaker.serve.builder.model_builder._maintain_lineage_tracking_for_mlflow_model")
+    @patch("sagemaker.serve.builder.tf_serving_builder.prepare_for_tf_serving")
+    @patch("sagemaker.serve.builder.model_builder.S3Downloader.list")
+    @patch("sagemaker.serve.builder.model_builder._detect_framework_and_version")
+    @patch("sagemaker.serve.builder.tf_serving_builder.save_pkl")
+    @patch("sagemaker.serve.builder.model_builder._download_s3_artifacts")
+    @patch("sagemaker.serve.builder.model_builder._generate_mlflow_artifact_path")
+    @patch("sagemaker.serve.builder.model_builder._get_all_flavor_metadata")
+    @patch("sagemaker.serve.builder.model_builder._select_container_for_mlflow_model")
+    @patch("sagemaker.serve.builder.model_builder._ServeSettings")
+    @patch("sagemaker.serve.builder.model_builder.SageMakerEndpointMode")
+    @patch("sagemaker.serve.builder.tf_serving_builder.TensorFlowModel")
+    @patch("builtins.open", new_callable=mock_open, read_data="data")
+    @patch("os.path.exists")
+    def test_build_mlflow_model_s3_input_tensorflow_serving_happy(
+        self,
+        mock_path_exists,
+        mock_open,
+        mock_sdk_model,
+        mock_sageMakerEndpointMode,
+        mock_serveSettings,
+        mock_detect_container,
+        mock_get_all_flavor_metadata,
+        mock_generate_mlflow_artifact_path,
+        mock_download_s3_artifacts,
+        mock_save_pkl,
+        mock_detect_fw_version,
+        mock_s3_downloader,
+        mock_prepare_for_tf_serving,
+        mock_lineage_tracking,
+    ):
+        # setup mocks
+        mock_s3_downloader.return_value = ["s3://some_path/MLmodel"]
+
+        mock_detect_container.return_value = mock_image_uri
+        mock_get_all_flavor_metadata.return_value = {"tensorflow": "some_data"}
+        mock_generate_mlflow_artifact_path.return_value = "some_path"
+
+        mock_prepare_for_tf_serving.return_value = mock_secret_key
+
+        # Mock _ServeSettings
+        mock_setting_object = mock_serveSettings.return_value
+        mock_setting_object.role_arn = mock_role_arn
+        mock_setting_object.s3_model_data_url = mock_s3_model_data_url
+
+        mock_path_exists.side_effect = lambda path: True if path == "test_path" else False
+
mock_mode = Mock() + mock_sageMakerEndpointMode.side_effect = lambda inference_spec, model_server: ( + mock_mode + if inference_spec is None and model_server == ModelServer.TENSORFLOW_SERVING + else None + ) + mock_mode.prepare.return_value = ( + model_data, + ENV_VAR_PAIR, + ) + + updated_env_var = deepcopy(ENV_VARS) + updated_env_var.update({"MLFLOW_MODEL_FLAVOR": "tensorflow"}) + mock_model_obj = Mock() + mock_sdk_model.return_value = mock_model_obj + + mock_session.sagemaker_client._user_agent_creator.to_string = lambda: "sample agent" + + # run + builder = ModelBuilder( + schema_builder=schema_builder, model_metadata={"MLFLOW_MODEL_PATH": "s3://test_path/"} + ) + build_result = builder.build(sagemaker_session=mock_session) + self.assertEqual(mock_model_obj, build_result) + self.assertEqual(build_result.mode, Mode.SAGEMAKER_ENDPOINT) + self.assertEqual(build_result.modes, {str(Mode.SAGEMAKER_ENDPOINT): mock_mode}) + self.assertEqual(build_result.serve_settings, mock_setting_object) + self.assertEqual(builder.env_vars["MLFLOW_MODEL_FLAVOR"], "tensorflow") + + build_result.deploy( + initial_instance_count=1, instance_type=mock_instance_type, mode=Mode.SAGEMAKER_ENDPOINT + ) + mock_lineage_tracking.assert_called_once() + + @patch("os.makedirs", Mock()) + @patch("sagemaker.serve.builder.tf_serving_builder.prepare_for_tf_serving") + @patch("sagemaker.serve.builder.model_builder.S3Downloader.list") + @patch("sagemaker.serve.builder.model_builder._detect_framework_and_version") + @patch("sagemaker.serve.builder.tf_serving_builder.save_pkl") + @patch("sagemaker.serve.builder.model_builder._download_s3_artifacts") + @patch("sagemaker.serve.builder.model_builder._generate_mlflow_artifact_path") + @patch("sagemaker.serve.builder.model_builder._get_all_flavor_metadata") + @patch("sagemaker.serve.builder.model_builder._select_container_for_mlflow_model") + @patch("sagemaker.serve.builder.model_builder._ServeSettings") + @patch("sagemaker.serve.builder.model_builder.LocalContainerMode") + @patch("sagemaker.serve.builder.tf_serving_builder.TensorFlowModel") + @patch("builtins.open", new_callable=mock_open, read_data="data") + @patch("os.path.exists") + def test_build_mlflow_model_s3_input_tensorflow_serving_local_mode_happy( + self, + mock_path_exists, + mock_open, + mock_sdk_model, + mock_local_container_mode, + mock_serveSettings, + mock_detect_container, + mock_get_all_flavor_metadata, + mock_generate_mlflow_artifact_path, + mock_download_s3_artifacts, + mock_save_pkl, + mock_detect_fw_version, + mock_s3_downloader, + mock_prepare_for_tf_serving, + ): + # setup mocks + mock_s3_downloader.return_value = ["s3://some_path/MLmodel"] + + mock_detect_container.return_value = mock_image_uri + mock_get_all_flavor_metadata.return_value = {"tensorflow": "some_data"} + mock_generate_mlflow_artifact_path.return_value = "some_path" + + mock_prepare_for_tf_serving.return_value = mock_secret_key + + # Mock _ServeSettings + mock_setting_object = mock_serveSettings.return_value + mock_setting_object.role_arn = mock_role_arn + mock_setting_object.s3_model_data_url = mock_s3_model_data_url + + mock_path_exists.side_effect = lambda path: True if path == "test_path" else False + + mock_mode = Mock() + mock_mode.prepare.side_effect = lambda: None + mock_local_container_mode.return_value = mock_mode + mock_mode.prepare.return_value = ( + model_data, + ENV_VAR_PAIR, + ) + + updated_env_var = deepcopy(ENV_VARS) + updated_env_var.update({"MLFLOW_MODEL_FLAVOR": "tensorflow"}) + mock_model_obj = Mock() + 
mock_sdk_model.return_value = mock_model_obj + + mock_session.sagemaker_client._user_agent_creator.to_string = lambda: "sample agent" + + # run + builder = ModelBuilder( + mode=Mode.LOCAL_CONTAINER, + schema_builder=schema_builder, + model_metadata={"MLFLOW_MODEL_PATH": "s3://test_path/"}, + ) + build_result = builder.build(sagemaker_session=mock_session) + self.assertEqual(mock_model_obj, build_result) + self.assertEqual(build_result.mode, Mode.LOCAL_CONTAINER) + self.assertEqual(build_result.modes, {str(Mode.LOCAL_CONTAINER): mock_mode}) + self.assertEqual(build_result.serve_settings, mock_setting_object) + self.assertEqual(builder.env_vars["MLFLOW_MODEL_FLAVOR"], "tensorflow") + + predictor = build_result.deploy() + assert isinstance(predictor, TensorflowServingLocalPredictor) + + @patch("os.makedirs", Mock()) + @patch("sagemaker.serve.builder.tf_serving_builder.prepare_for_tf_serving") + @patch("sagemaker.serve.builder.model_builder.S3Downloader.list") + @patch("sagemaker.serve.builder.model_builder._detect_framework_and_version") + @patch("sagemaker.serve.builder.model_builder.save_pkl") + @patch("sagemaker.serve.builder.model_builder._download_s3_artifacts") + @patch("sagemaker.serve.builder.model_builder._generate_mlflow_artifact_path") + @patch("sagemaker.serve.builder.model_builder._get_all_flavor_metadata") + @patch("sagemaker.serve.builder.model_builder._select_container_for_mlflow_model") + @patch("sagemaker.serve.builder.model_builder._ServeSettings") + @patch("sagemaker.serve.builder.model_builder.SageMakerEndpointMode") + @patch("sagemaker.serve.builder.tf_serving_builder.TensorFlowModel") + @patch("builtins.open", new_callable=mock_open, read_data="data") + @patch("os.path.exists") + def test_build_tensorflow_serving_non_mlflow_case( + self, + mock_path_exists, + mock_open, + mock_sdk_model, + mock_sageMakerEndpointMode, + mock_serveSettings, + mock_detect_container, + mock_get_all_flavor_metadata, + mock_generate_mlflow_artifact_path, + mock_download_s3_artifacts, + mock_save_pkl, + mock_detect_fw_version, + mock_s3_downloader, + mock_prepare_for_tf_serving, + ): + mock_s3_downloader.return_value = [] + mock_detect_container.return_value = mock_image_uri + mock_get_all_flavor_metadata.return_value = {"tensorflow": "some_data"} + mock_generate_mlflow_artifact_path.return_value = "some_path" + + mock_prepare_for_tf_serving.return_value = mock_secret_key + + # Mock _ServeSettings + mock_setting_object = mock_serveSettings.return_value + mock_setting_object.role_arn = mock_role_arn + mock_setting_object.s3_model_data_url = mock_s3_model_data_url + + mock_path_exists.side_effect = lambda path: True if path == "test_path" else False + + mock_mode = Mock() + mock_sageMakerEndpointMode.side_effect = lambda inference_spec, model_server: ( + mock_mode + if inference_spec is None and model_server == ModelServer.TENSORFLOW_SERVING + else None + ) + mock_mode.prepare.return_value = ( + model_data, + ENV_VAR_PAIR, + ) + + updated_env_var = deepcopy(ENV_VARS) + updated_env_var.update({"MLFLOW_MODEL_FLAVOR": "tensorflow"}) + mock_model_obj = Mock() + mock_sdk_model.return_value = mock_model_obj + + mock_session.sagemaker_client._user_agent_creator.to_string = lambda: "sample agent" + + # run + builder = ModelBuilder( + model=mock_fw_model, + schema_builder=schema_builder, + model_server=ModelServer.TENSORFLOW_SERVING, + ) + + self.assertRaisesRegex( + Exception, + "Tensorflow Serving is currently only supported for mlflow models.", + builder.build, + Mode.SAGEMAKER_ENDPOINT, + 
mock_role_arn, + mock_session, + ) diff --git a/tests/unit/sagemaker/serve/builder/test_tensorflow_serving_builder.py b/tests/unit/sagemaker/serve/builder/test_tensorflow_serving_builder.py new file mode 100644 index 0000000000..9d51b04e08 --- /dev/null +++ b/tests/unit/sagemaker/serve/builder/test_tensorflow_serving_builder.py @@ -0,0 +1,75 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import +from unittest.mock import MagicMock, patch + +import unittest +from pathlib import Path + +from sagemaker.serve import ModelBuilder, ModelServer + + +class TestTensorflowServingBuilder(unittest.TestCase): + def setUp(self): + self.instance = ModelBuilder() + self.instance.model_server = ModelServer.TENSORFLOW_SERVING + self.instance.model_path = "/fake/model/path" + self.instance.image_uri = "fake_image_uri" + self.instance.s3_upload_path = "s3://bucket/path" + self.instance.serve_settings = MagicMock(role_arn="fake_role_arn") + self.instance.schema_builder = MagicMock() + self.instance.env_vars = {} + self.instance.sagemaker_session = MagicMock() + self.instance.image_config = {} + self.instance.vpc_config = {} + self.instance.modes = {} + + @patch("os.makedirs") + @patch("os.path.exists") + @patch("sagemaker.serve.builder.tf_serving_builder.save_pkl") + def test_save_schema_builder(self, mock_save_pkl, mock_exists, mock_makedirs): + mock_exists.return_value = False + self.instance._save_schema_builder() + mock_makedirs.assert_called_once_with(self.instance.model_path) + code_path = Path(self.instance.model_path).joinpath("code") + mock_save_pkl.assert_called_once_with(code_path, self.instance.schema_builder) + + @patch("sagemaker.serve.builder.tf_serving_builder.TensorflowServing._get_client_translators") + @patch("sagemaker.serve.builder.tf_serving_builder.TensorFlowPredictor") + def test_get_tensorflow_predictor(self, mock_predictor, mock_get_marshaller): + endpoint_name = "test_endpoint" + predictor = self.instance._get_tensorflow_predictor( + endpoint_name, self.instance.sagemaker_session + ) + mock_predictor.assert_called_once_with( + endpoint_name=endpoint_name, + sagemaker_session=self.instance.sagemaker_session, + serializer=self.instance.schema_builder.custom_input_translator, + deserializer=self.instance.schema_builder.custom_output_translator, + ) + self.assertEqual(predictor, mock_predictor.return_value) + + @patch("sagemaker.serve.builder.tf_serving_builder.TensorFlowModel") + def test_create_tensorflow_model(self, mock_model): + model = self.instance._create_tensorflow_model() + mock_model.assert_called_once_with( + image_uri=self.instance.image_uri, + image_config=self.instance.image_config, + vpc_config=self.instance.vpc_config, + model_data=self.instance.s3_upload_path, + role=self.instance.serve_settings.role_arn, + env=self.instance.env_vars, + sagemaker_session=self.instance.sagemaker_session, + predictor_cls=self.instance._get_tensorflow_predictor, + ) + self.assertEqual(model, mock_model.return_value) diff --git
a/tests/unit/sagemaker/serve/model_format/mlflow/test_mlflow_utils.py b/tests/unit/sagemaker/serve/model_format/mlflow/test_mlflow_utils.py index 154b6b7d95..27c3a5280f 100644 --- a/tests/unit/sagemaker/serve/model_format/mlflow/test_mlflow_utils.py +++ b/tests/unit/sagemaker/serve/model_format/mlflow/test_mlflow_utils.py @@ -13,6 +13,7 @@ from __future__ import absolute_import import os +from pathlib import Path from unittest.mock import patch, MagicMock, mock_open import pytest @@ -21,6 +22,7 @@ from sagemaker.serve import ModelServer from sagemaker.serve.model_format.mlflow.constants import ( MLFLOW_PYFUNC, + TENSORFLOW_SAVED_MODEL_NAME, ) from sagemaker.serve.model_format.mlflow.utils import ( _get_default_model_server_for_mlflow, @@ -35,6 +37,8 @@ _select_container_for_mlflow_model, _validate_input_for_mlflow, _copy_directory_contents, + _move_contents, + _get_saved_model_path_for_tensorflow_and_keras_flavor, ) @@ -415,10 +419,15 @@ def test_select_container_for_mlflow_model_no_dlc_detected( def test_validate_input_for_mlflow(): - _validate_input_for_mlflow(ModelServer.TORCHSERVE) + _validate_input_for_mlflow(ModelServer.TORCHSERVE, "pytorch") with pytest.raises(ValueError): - _validate_input_for_mlflow(ModelServer.DJL_SERVING) + _validate_input_for_mlflow(ModelServer.DJL_SERVING, "pytorch") + + +def test_validate_input_for_mlflow_non_supported_flavor_with_tf_serving(): + with pytest.raises(ValueError): + _validate_input_for_mlflow(ModelServer.TENSORFLOW_SERVING, "pytorch") @patch("sagemaker.serve.model_format.mlflow.utils.shutil.copy2") @@ -472,3 +481,68 @@ def test_copy_directory_contents_handles_same_src_dst( mock_os_walk.assert_not_called() mock_os_makedirs.assert_not_called() mock_shutil_copy2.assert_not_called() + + +@patch("os.path.abspath") +@patch("os.walk") +def test_get_saved_model_path_found(mock_os_walk, mock_os_abspath): + mock_os_walk.return_value = [ + ("/root/folder1", ("subfolder",), ()), + ("/root/folder1/subfolder", (), (TENSORFLOW_SAVED_MODEL_NAME,)), + ] + expected_path = "/root/folder1/subfolder" + mock_os_abspath.return_value = expected_path + + # Call the function + result = _get_saved_model_path_for_tensorflow_and_keras_flavor("/root/folder1") + + # Assertions + mock_os_walk.assert_called_once_with("/root/folder1") + mock_os_abspath.assert_called_once_with("/root/folder1/subfolder") + assert result == expected_path + + +@patch("os.path.abspath") +@patch("os.walk") +def test_get_saved_model_path_not_found(mock_os_walk, mock_os_abspath): + mock_os_walk.return_value = [ + ("/root/folder2", ("subfolder",), ()), + ("/root/folder2/subfolder", (), ("not_saved_model.pb",)), + ] + + result = _get_saved_model_path_for_tensorflow_and_keras_flavor("/root/folder2") + + mock_os_walk.assert_called_once_with("/root/folder2") + mock_os_abspath.assert_not_called() + assert result is None + + +@patch("sagemaker.serve.model_format.mlflow.utils.shutil.move") +@patch("sagemaker.serve.model_format.mlflow.utils.Path.iterdir") +@patch("sagemaker.serve.model_format.mlflow.utils.Path.mkdir") +def test_move_contents_handles_same_src_dst(mock_mkdir, mock_iterdir, mock_shutil_move): + src_dir = "/fake/source/dir" + dest_dir = "/fake/source/./dir" + + mock_iterdir.return_value = [] + + _move_contents(src_dir, dest_dir) + + mock_mkdir.assert_called_once_with(parents=True, exist_ok=True) + mock_shutil_move.assert_not_called() + + +@patch("sagemaker.serve.model_format.mlflow.utils.shutil.move") +@patch("sagemaker.serve.model_format.mlflow.utils.Path.iterdir") 
+@patch("sagemaker.serve.model_format.mlflow.utils.Path.mkdir") +def test_move_contents_with_actual_files(mock_mkdir, mock_iterdir, mock_shutil_move): + src_dir = Path("/fake/source/dir") + dest_dir = Path("/fake/destination/dir") + + file_path = src_dir / "testfile.txt" + mock_iterdir.return_value = [file_path] + + _move_contents(src_dir, dest_dir) + + mock_mkdir.assert_called_once_with(parents=True, exist_ok=True) + mock_shutil_move.assert_called_once_with(str(file_path), str(dest_dir / "testfile.txt")) diff --git a/tests/unit/sagemaker/serve/model_server/tensorflow_serving/test_tf_prepare.py b/tests/unit/sagemaker/serve/model_server/tensorflow_serving/test_tf_prepare.py new file mode 100644 index 0000000000..9915b19649 --- /dev/null +++ b/tests/unit/sagemaker/serve/model_server/tensorflow_serving/test_tf_prepare.py @@ -0,0 +1,116 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +from unittest import TestCase +from unittest.mock import Mock, patch, mock_open +import pytest + +from sagemaker.serve.model_server.tensorflow_serving.prepare import prepare_for_tf_serving + +MODEL_PATH = "/path/to/your/model/dir" +SHARED_LIBS = ["/path/to/shared/libs"] +DEPENDENCIES = {"dependencies": "requirements.txt"} +INFERENCE_SPEC = Mock() +IMAGE_URI = "mock_image_uri" +XGB_1P_IMAGE_URI = "246618743249.dkr.ecr.us-west-2.amazonaws.com/sagemaker-tensorflow:1.7-1" +INFERENCE_SPEC.prepare = Mock(return_value=None) + +SECRET_KEY = "secret-key" + +mock_session = Mock() + + +class PrepareForTensorflowServingTests(TestCase): + def setUp(self): + INFERENCE_SPEC.reset_mock() + + @patch("builtins.open", new_callable=mock_open, read_data=b"{}") + @patch("sagemaker.serve.model_server.tensorflow_serving.prepare._move_contents") + @patch( + "sagemaker.serve.model_server.tensorflow_serving.prepare." 
+ "_get_saved_model_path_for_tensorflow_and_keras_flavor" + ) + @patch("sagemaker.serve.model_server.tensorflow_serving.prepare._MetaData") + @patch("sagemaker.serve.model_server.tensorflow_serving.prepare.compute_hash") + @patch("sagemaker.serve.model_server.tensorflow_serving.prepare.generate_secret_key") + @patch("sagemaker.serve.model_server.tensorflow_serving.prepare.capture_dependencies") + @patch("sagemaker.serve.model_server.tensorflow_serving.prepare.shutil") + @patch("sagemaker.serve.model_server.tensorflow_serving.prepare.Path") + def test_prepare_happy( + self, + mock_path, + mock_shutil, + mock_capture_dependencies, + mock_generate_secret_key, + mock_compute_hash, + mock_metadata, + mock_get_saved_model_path, + mock_move_contents, + mock_open, + ): + + mock_path_instance = mock_path.return_value + mock_path_instance.exists.return_value = True + mock_path_instance.joinpath.return_value = Mock() + mock_get_saved_model_path.return_value = MODEL_PATH + "/1/" + + mock_generate_secret_key.return_value = SECRET_KEY + + secret_key = prepare_for_tf_serving( + model_path=MODEL_PATH, + shared_libs=SHARED_LIBS, + dependencies=DEPENDENCIES, + ) + + mock_path_instance.mkdir.assert_not_called() + self.assertEqual(secret_key, SECRET_KEY) + + @patch("builtins.open", new_callable=mock_open, read_data=b"{}") + @patch("sagemaker.serve.model_server.tensorflow_serving.prepare._move_contents") + @patch( + "sagemaker.serve.model_server.tensorflow_serving.prepare." + "_get_saved_model_path_for_tensorflow_and_keras_flavor" + ) + @patch("sagemaker.serve.model_server.tensorflow_serving.prepare._MetaData") + @patch("sagemaker.serve.model_server.tensorflow_serving.prepare.compute_hash") + @patch("sagemaker.serve.model_server.tensorflow_serving.prepare.generate_secret_key") + @patch("sagemaker.serve.model_server.tensorflow_serving.prepare.capture_dependencies") + @patch("sagemaker.serve.model_server.tensorflow_serving.prepare.shutil") + @patch("sagemaker.serve.model_server.tensorflow_serving.prepare.Path") + def test_prepare_saved_model_not_found( + self, + mock_path, + mock_shutil, + mock_capture_dependencies, + mock_generate_secret_key, + mock_compute_hash, + mock_metadata, + mock_get_saved_model_path, + mock_move_contents, + mock_open, + ): + + mock_path_instance = mock_path.return_value + mock_path_instance.exists.return_value = True + mock_path_instance.joinpath.return_value = Mock() + mock_get_saved_model_path.return_value = None + + with pytest.raises( + ValueError, match="SavedModel is not found for Tensorflow or Keras flavor." + ): + prepare_for_tf_serving( + model_path=MODEL_PATH, + shared_libs=SHARED_LIBS, + dependencies=DEPENDENCIES, + ) diff --git a/tests/unit/sagemaker/serve/model_server/tensorflow_serving/test_tf_server.py b/tests/unit/sagemaker/serve/model_server/tensorflow_serving/test_tf_server.py new file mode 100644 index 0000000000..3d3bac0935 --- /dev/null +++ b/tests/unit/sagemaker/serve/model_server/tensorflow_serving/test_tf_server.py @@ -0,0 +1,100 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. 
See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +from pathlib import PosixPath +import platform +from unittest import TestCase +from unittest.mock import Mock, patch, ANY + +import numpy as np + +from sagemaker.serve.model_server.tensorflow_serving.server import ( + LocalTensorflowServing, + SageMakerTensorflowServing, +) + +CPU_TF_IMAGE = "763104351884.dkr.ecr.us-west-2.amazonaws.com/tensorflow-inference:2.14.1-cpu" +MODEL_PATH = "model_path" +MODEL_REPO = f"{MODEL_PATH}/1" +ENV_VAR = {"KEY": "VALUE"} +_SHM_SIZE = "2G" +PAYLOAD = np.random.rand(3, 4).astype(dtype=np.float32) +S3_URI = "s3://mock_model_data_uri" +DTYPE = "TYPE_FP32" +SECRET_KEY = "secret_key" + +INFER_RESPONSE = {"outputs": [{"name": "output_name"}]} + + +class TensorflowServingServerTests(TestCase): + def test_start_invoke_destroy_local_tensorflow_serving_server(self): + mock_container = Mock() + mock_docker_client = Mock() + mock_docker_client.containers.run.return_value = mock_container + + local_tensorflow_server = LocalTensorflowServing() + mock_schema_builder = Mock() + mock_schema_builder.input_serializer.serialize.return_value = PAYLOAD + local_tensorflow_server.schema_builder = mock_schema_builder + + local_tensorflow_server._start_tensorflow_serving( + client=mock_docker_client, + model_path=MODEL_PATH, + secret_key=SECRET_KEY, + env_vars=ENV_VAR, + image=CPU_TF_IMAGE, + ) + + mock_docker_client.containers.run.assert_called_once_with( + CPU_TF_IMAGE, + "serve", + detach=True, + auto_remove=True, + network_mode="host", + volumes={PosixPath("model_path"): {"bind": "/opt/ml/model", "mode": "rw"}}, + environment={ + "SAGEMAKER_SUBMIT_DIRECTORY": "/opt/ml/model/code", + "SAGEMAKER_PROGRAM": "inference.py", + "SAGEMAKER_SERVE_SECRET_KEY": "secret_key", + "LOCAL_PYTHON": platform.python_version(), + "KEY": "VALUE", + }, + ) + + @patch("sagemaker.serve.model_server.tensorflow_serving.server.platform") + @patch("sagemaker.serve.model_server.tensorflow_serving.server.upload") + def test_upload_artifacts_sagemaker_tensorflow_serving_server(self, mock_upload, mock_platform): + mock_session = Mock() + mock_platform.python_version.return_value = "3.8" + mock_upload.side_effect = lambda session, repo, bucket, prefix: ( + S3_URI + if session == mock_session and repo == MODEL_PATH and bucket == "mock_model_data_uri" + else None + ) + + ( + s3_upload_path, + env_vars, + ) = SageMakerTensorflowServing()._upload_tensorflow_serving_artifacts( + model_path=MODEL_PATH, + sagemaker_session=mock_session, + secret_key=SECRET_KEY, + s3_model_data_url=S3_URI, + image=CPU_TF_IMAGE, + ) + + mock_upload.assert_called_once_with(mock_session, MODEL_PATH, "mock_model_data_uri", ANY) + self.assertEqual(s3_upload_path, S3_URI) + self.assertEqual(env_vars.get("SAGEMAKER_SERVE_SECRET_KEY"), SECRET_KEY) + self.assertEqual(env_vars.get("LOCAL_PYTHON"), "3.8") diff --git a/tests/unit/sagemaker/serve/utils/test_lineage_utils.py b/tests/unit/sagemaker/serve/utils/test_lineage_utils.py new file mode 100644 index 0000000000..25e4fe246e --- /dev/null +++ b/tests/unit/sagemaker/serve/utils/test_lineage_utils.py @@ -0,0 +1,374 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file.
This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +from unittest.mock import call + +import pytest +from botocore.exceptions import ClientError +from mock import Mock, patch +from sagemaker import Session +from sagemaker.lineage.artifact import ArtifactSummary, Artifact +from sagemaker.lineage.query import LineageSourceEnum + +from sagemaker.serve.utils.lineage_constants import ( + MLFLOW_RUN_ID, + MLFLOW_MODEL_PACKAGE_PATH, + MLFLOW_S3_PATH, + MLFLOW_LOCAL_PATH, + LINEAGE_POLLER_MAX_TIMEOUT_SECS, + MODEL_BUILDER_MLFLOW_MODEL_PATH_LINEAGE_ARTIFACT_TYPE, + CONTRIBUTED_TO, + MLFLOW_REGISTRY_PATH, +) +from sagemaker.serve.utils.lineage_utils import ( + _load_artifact_by_source_uri, + _poll_lineage_artifact, + _get_mlflow_model_path_type, + _create_mlflow_model_path_lineage_artifact, + _add_association_between_artifacts, + _retrieve_and_create_if_not_exist_mlflow_model_path_lineage_artifact, + _maintain_lineage_tracking_for_mlflow_model, +) + + +@patch("sagemaker.lineage.artifact.Artifact.list") +def test_load_artifact_by_source_uri(mock_artifact_list): + source_uri = "s3://mybucket/mymodel" + sagemaker_session = Mock(spec=Session) + + mock_artifact_1 = Mock(spec=ArtifactSummary) + mock_artifact_1.artifact_type = LineageSourceEnum.MODEL_DATA.value + mock_artifact_2 = Mock(spec=ArtifactSummary) + mock_artifact_2.artifact_type = LineageSourceEnum.IMAGE.value + mock_artifacts = [mock_artifact_1, mock_artifact_2] + mock_artifact_list.return_value = mock_artifacts + + result = _load_artifact_by_source_uri( + source_uri, LineageSourceEnum.MODEL_DATA.value, sagemaker_session + ) + + mock_artifact_list.assert_called_once_with( + source_uri=source_uri, sagemaker_session=sagemaker_session + ) + assert result == mock_artifact_1 + + +@patch("sagemaker.lineage.artifact.Artifact.list") +def test_load_artifact_by_source_uri_no_match(mock_artifact_list): + source_uri = "s3://mybucket/mymodel" + sagemaker_session = Mock(spec=Session) + + mock_artifact_1 = Mock(spec=ArtifactSummary) + mock_artifact_1.artifact_type = LineageSourceEnum.IMAGE.value + mock_artifact_2 = Mock(spec=ArtifactSummary) + mock_artifact_2.artifact_type = LineageSourceEnum.IMAGE.value + mock_artifacts = [mock_artifact_1, mock_artifact_2] + mock_artifact_list.return_value = mock_artifacts + + result = _load_artifact_by_source_uri( + source_uri, LineageSourceEnum.MODEL_DATA.value, sagemaker_session + ) + + mock_artifact_list.assert_called_once_with( + source_uri=source_uri, sagemaker_session=sagemaker_session + ) + assert result is None + + +@patch("sagemaker.serve.utils.lineage_utils._load_artifact_by_source_uri") +def test_poll_lineage_artifact_found(mock_load_artifact): + s3_uri = "s3://mybucket/mymodel" + sagemaker_session = Mock(spec=Session) + mock_artifact = Mock(spec=ArtifactSummary) + + with patch("time.time") as mock_time: + mock_time.return_value = 0 + + mock_load_artifact.return_value = mock_artifact + + result = _poll_lineage_artifact( + s3_uri, LineageSourceEnum.MODEL_DATA.value, sagemaker_session + ) + + assert result == mock_artifact + mock_load_artifact.assert_has_calls( + [ + call(s3_uri, LineageSourceEnum.MODEL_DATA.value, sagemaker_session), + ] + ) + + +@patch("sagemaker.serve.utils.lineage_utils._load_artifact_by_source_uri") +def test_poll_lineage_artifact_not_found(mock_load_artifact): + s3_uri = 
"s3://mybucket/mymodel" + artifact_type = LineageSourceEnum.MODEL_DATA.value + sagemaker_session = Mock(spec=Session) + + with patch("time.time") as mock_time: + mock_time_values = [0.0, 1.0, LINEAGE_POLLER_MAX_TIMEOUT_SECS + 1.0] + mock_time.side_effect = mock_time_values + + with patch("time.sleep"): + mock_load_artifact.side_effect = [None, None, None] + + result = _poll_lineage_artifact(s3_uri, artifact_type, sagemaker_session) + + assert result is None + + +@pytest.mark.parametrize( + "mlflow_model_path, expected_output", + [ + ("runs:/abc123", MLFLOW_RUN_ID), + ("models:/my-model/1", MLFLOW_REGISTRY_PATH), + ( + "arn:aws:sagemaker:us-west-2:123456789012:model-package/my-model-package", + MLFLOW_MODEL_PACKAGE_PATH, + ), + ("s3://my-bucket/path/to/model", MLFLOW_S3_PATH), + ], +) +def test_get_mlflow_model_path_type_valid(mlflow_model_path, expected_output): + result = _get_mlflow_model_path_type(mlflow_model_path) + assert result == expected_output + + +@patch("os.path.exists") +def test_get_mlflow_model_path_type_valid_local_path(mock_path_exists): + valid_path = "/path/to/mlflow_model" + mock_path_exists.side_effect = lambda path: path == valid_path + result = _get_mlflow_model_path_type(valid_path) + assert result == MLFLOW_LOCAL_PATH + + +def test_get_mlflow_model_path_type_invalid(): + invalid_path = "invalid_path" + with pytest.raises(ValueError, match=f"Invalid MLflow model path: {invalid_path}"): + _get_mlflow_model_path_type(invalid_path) + + +@patch("sagemaker.serve.utils.lineage_utils._get_mlflow_model_path_type") +@patch("sagemaker.lineage.artifact.Artifact.create") +def test_create_mlflow_model_path_lineage_artifact_success( + mock_artifact_create, mock_get_mlflow_path_type +): + mlflow_model_path = "runs:/Ab12Cd34" + sagemaker_session = Mock(spec=Session) + mock_artifact = Mock(spec=Artifact) + mock_get_mlflow_path_type.return_value = "mlflow_run_id" + mock_artifact_create.return_value = mock_artifact + + result = _create_mlflow_model_path_lineage_artifact(mlflow_model_path, sagemaker_session) + + assert result == mock_artifact + mock_get_mlflow_path_type.assert_called_once_with(mlflow_model_path) + mock_artifact_create.assert_called_once_with( + source_uri=mlflow_model_path, + artifact_type=MODEL_BUILDER_MLFLOW_MODEL_PATH_LINEAGE_ARTIFACT_TYPE, + artifact_name="mlflow_run_id", + properties={"model_builder_input_model_data_type": "mlflow_run_id"}, + sagemaker_session=sagemaker_session, + ) + + +@patch("sagemaker.serve.utils.lineage_utils._get_mlflow_model_path_type") +@patch("sagemaker.lineage.artifact.Artifact.create") +def test_create_mlflow_model_path_lineage_artifact_validation_exception( + mock_artifact_create, mock_get_mlflow_path_type +): + mlflow_model_path = "runs:/Ab12Cd34" + sagemaker_session = Mock(spec=Session) + mock_get_mlflow_path_type.return_value = "mlflow_run_id" + mock_artifact_create.side_effect = ClientError( + error_response={"Error": {"Code": "ValidationException"}}, operation_name="CreateArtifact" + ) + + result = _create_mlflow_model_path_lineage_artifact(mlflow_model_path, sagemaker_session) + + assert result is None + + +@patch("sagemaker.serve.utils.lineage_utils._get_mlflow_model_path_type") +@patch("sagemaker.lineage.artifact.Artifact.create") +def test_create_mlflow_model_path_lineage_artifact_other_exception( + mock_artifact_create, mock_get_mlflow_path_type +): + mlflow_model_path = "runs:/Ab12Cd34" + sagemaker_session = Mock(spec=Session) + mock_get_mlflow_path_type.return_value = "mlflow_run_id" + mock_artifact_create.side_effect = 
ClientError( + error_response={"Error": {"Code": "SomeOtherException"}}, operation_name="CreateArtifact" + ) + + with pytest.raises(ClientError): + _create_mlflow_model_path_lineage_artifact(mlflow_model_path, sagemaker_session) + + +@patch("sagemaker.serve.utils.lineage_utils._create_mlflow_model_path_lineage_artifact") +@patch("sagemaker.serve.utils.lineage_utils._load_artifact_by_source_uri") +def test_retrieve_and_create_if_not_exist_mlflow_model_path_lineage_artifact_existing( + mock_load_artifact, mock_create_artifact +): + mlflow_model_path = "runs:/Ab12Cd34" + sagemaker_session = Mock(spec=Session) + mock_artifact_summary = Mock(spec=ArtifactSummary) + mock_load_artifact.return_value = mock_artifact_summary + + result = _retrieve_and_create_if_not_exist_mlflow_model_path_lineage_artifact( + mlflow_model_path, sagemaker_session + ) + + assert result == mock_artifact_summary + mock_load_artifact.assert_called_once_with( + mlflow_model_path, MODEL_BUILDER_MLFLOW_MODEL_PATH_LINEAGE_ARTIFACT_TYPE, sagemaker_session + ) + mock_create_artifact.assert_not_called() + + +@patch("sagemaker.serve.utils.lineage_utils._create_mlflow_model_path_lineage_artifact") +@patch("sagemaker.serve.utils.lineage_utils._load_artifact_by_source_uri") +def test_retrieve_and_create_if_not_exist_mlflow_model_path_lineage_artifact_create( + mock_load_artifact, mock_create_artifact +): + mlflow_model_path = "runs:/Ab12Cd34" + sagemaker_session = Mock(spec=Session) + mock_artifact = Mock(spec=Artifact) + mock_load_artifact.return_value = None + mock_create_artifact.return_value = mock_artifact + + result = _retrieve_and_create_if_not_exist_mlflow_model_path_lineage_artifact( + mlflow_model_path, sagemaker_session + ) + + assert result == mock_artifact + mock_load_artifact.assert_called_once_with( + mlflow_model_path, MODEL_BUILDER_MLFLOW_MODEL_PATH_LINEAGE_ARTIFACT_TYPE, sagemaker_session + ) + mock_create_artifact.assert_called_once_with(mlflow_model_path, sagemaker_session) + + +@patch("sagemaker.lineage.association.Association.create") +def test_add_association_between_artifacts_success(mock_association_create): + mlflow_model_path_artifact_arn = "arn:aws:sagemaker:us-west-2:123456789012:artifact/123" + autogenerated_model_data_artifact_arn = "arn:aws:sagemaker:us-west-2:123456789012:artifact/456" + sagemaker_session = Mock(spec=Session) + + _add_association_between_artifacts( + mlflow_model_path_artifact_arn, + autogenerated_model_data_artifact_arn, + sagemaker_session, + ) + + mock_association_create.assert_called_once_with( + source_arn=mlflow_model_path_artifact_arn, + destination_arn=autogenerated_model_data_artifact_arn, + association_type=CONTRIBUTED_TO, + sagemaker_session=sagemaker_session, + ) + + +@patch("sagemaker.lineage.association.Association.create") +def test_add_association_between_artifacts_validation_exception(mock_association_create): + mlflow_model_path_artifact_arn = "arn:aws:sagemaker:us-west-2:123456789012:artifact/123" + autogenerated_model_data_artifact_arn = "arn:aws:sagemaker:us-west-2:123456789012:artifact/456" + sagemaker_session = Mock(spec=Session) + mock_association_create.side_effect = ClientError( + error_response={"Error": {"Code": "ValidationException"}}, + operation_name="CreateAssociation", + ) + + _add_association_between_artifacts( + mlflow_model_path_artifact_arn, + autogenerated_model_data_artifact_arn, + sagemaker_session, + ) + + +@patch("sagemaker.lineage.association.Association.create") +def 
test_add_association_between_artifacts_other_exception(mock_association_create): + mlflow_model_path_artifact_arn = "arn:aws:sagemaker:us-west-2:123456789012:artifact/123" + autogenerated_model_data_artifact_arn = "arn:aws:sagemaker:us-west-2:123456789012:artifact/456" + sagemaker_session = Mock(spec=Session) + mock_association_create.side_effect = ClientError( + error_response={"Error": {"Code": "SomeOtherException"}}, operation_name="CreateAssociation" + ) + + with pytest.raises(ClientError): + _add_association_between_artifacts( + mlflow_model_path_artifact_arn, + autogenerated_model_data_artifact_arn, + sagemaker_session, + ) + + +@patch("sagemaker.serve.utils.lineage_utils._poll_lineage_artifact") +@patch( + "sagemaker.serve.utils.lineage_utils._retrieve_and_create_if_not_exist_mlflow_model_path_lineage_artifact" +) +@patch("sagemaker.serve.utils.lineage_utils._add_association_between_artifacts") +def test_maintain_lineage_tracking_for_mlflow_model_success( + mock_add_association, mock_retrieve_create_artifact, mock_poll_artifact +): + mlflow_model_path = "runs:/Ab12Cd34" + s3_upload_path = "s3://mybucket/path/to/model" + sagemaker_session = Mock(spec=Session) + mock_model_data_artifact = Mock(spec=ArtifactSummary) + mock_mlflow_model_artifact = Mock(spec=Artifact) + mock_poll_artifact.return_value = mock_model_data_artifact + mock_retrieve_create_artifact.return_value = mock_mlflow_model_artifact + + _maintain_lineage_tracking_for_mlflow_model( + mlflow_model_path, s3_upload_path, sagemaker_session + ) + + mock_poll_artifact.assert_called_once_with( + s3_uri=s3_upload_path, + artifact_type=LineageSourceEnum.MODEL_DATA.value, + sagemaker_session=sagemaker_session, + ) + mock_retrieve_create_artifact.assert_called_once_with( + mlflow_model_path=mlflow_model_path, sagemaker_session=sagemaker_session + ) + mock_add_association.assert_called_once_with( + mlflow_model_path_artifact_arn=mock_mlflow_model_artifact.artifact_arn, + autogenerated_model_data_artifact_arn=mock_model_data_artifact.artifact_arn, + sagemaker_session=sagemaker_session, + ) + + +@patch("sagemaker.serve.utils.lineage_utils._poll_lineage_artifact") +@patch( + "sagemaker.serve.utils.lineage_utils._retrieve_and_create_if_not_exist_mlflow_model_path_lineage_artifact" +) +@patch("sagemaker.serve.utils.lineage_utils._add_association_between_artifacts") +def test_maintain_lineage_tracking_for_mlflow_model_no_model_data_artifact( + mock_add_association, mock_retrieve_create_artifact, mock_poll_artifact +): + mlflow_model_path = "runs:/Ab12Cd34" + s3_upload_path = "s3://mybucket/path/to/model" + sagemaker_session = Mock(spec=Session) + mock_poll_artifact.return_value = None + mock_retrieve_create_artifact.return_value = None + + _maintain_lineage_tracking_for_mlflow_model( + mlflow_model_path, s3_upload_path, sagemaker_session + ) + + mock_poll_artifact.assert_called_once_with( + s3_uri=s3_upload_path, + artifact_type=LineageSourceEnum.MODEL_DATA.value, + sagemaker_session=sagemaker_session, + ) + mock_retrieve_create_artifact.assert_not_called() + mock_add_association.assert_not_called()
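Reviewer note: taken together, the new tests pin down the user-facing flow this change enables. The sketch below is reconstructed from the test setup above, not code from this diff; the bucket path, sample arrays, role ARN, instance type, and session are placeholder assumptions.

    import numpy as np
    from sagemaker import Session
    from sagemaker.serve.builder.model_builder import ModelBuilder, Mode
    from sagemaker.serve.builder.schema_builder import SchemaBuilder

    sagemaker_session = Session()  # assumption: default session with a suitable execution role
    sample = np.zeros((1, 28, 28), dtype=np.float32)  # placeholder sample input/output payload

    model_builder = ModelBuilder(
        mode=Mode.SAGEMAKER_ENDPOINT,
        schema_builder=SchemaBuilder(sample_input=sample, sample_output=sample),
        role_arn="arn:aws:iam::111122223333:role/SageMakerRole",  # placeholder
        sagemaker_session=sagemaker_session,
        # A tensorflow/keras MLflow flavor routes build() to _build_for_tensorflow_serving().
        model_metadata={"MLFLOW_MODEL_PATH": "s3://my-bucket/my-mlflow-model/"},  # placeholder
    )
    model = model_builder.build(sagemaker_session=sagemaker_session)
    # In SAGEMAKER_ENDPOINT mode, deploy() also triggers the MLflow lineage tracking
    # that the lineage assertions in the tests above verify.
    predictor = model.deploy(instance_type="ml.c5.xlarge", initial_instance_count=1)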
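The lineage unit tests also fix the call sequence inside _maintain_lineage_tracking_for_mlflow_model. A minimal sketch of that sequence, inferred from the mocked helpers and their asserted arguments (the function body below is a reconstruction, not the shipped implementation):

    from sagemaker.lineage.query import LineageSourceEnum
    from sagemaker.serve.utils.lineage_utils import (
        _add_association_between_artifacts,
        _poll_lineage_artifact,
        _retrieve_and_create_if_not_exist_mlflow_model_path_lineage_artifact,
    )


    def _maintain_lineage_tracking_sketch(mlflow_model_path, s3_upload_path, sagemaker_session):
        # Poll (bounded by LINEAGE_POLLER_MAX_TIMEOUT_SECS) for the ModelData artifact
        # that SageMaker auto-creates for the uploaded model archive.
        model_data_artifact = _poll_lineage_artifact(
            s3_uri=s3_upload_path,
            artifact_type=LineageSourceEnum.MODEL_DATA.value,
            sagemaker_session=sagemaker_session,
        )
        if model_data_artifact is None:
            # The tests assert the remaining helpers are skipped in this case.
            return
        # Load an existing MLflow-model-path artifact by source URI, or create one;
        # per the tests, _get_mlflow_model_path_type classifies runs:/ and models:/
        # URIs, model-package ARNs, S3 URIs, and local paths.
        mlflow_artifact = _retrieve_and_create_if_not_exist_mlflow_model_path_lineage_artifact(
            mlflow_model_path=mlflow_model_path, sagemaker_session=sagemaker_session
        )
        # Link the MLflow input to the generated model data with a CONTRIBUTED_TO association.
        _add_association_between_artifacts(
            mlflow_model_path_artifact_arn=mlflow_artifact.artifact_arn,
            autogenerated_model_data_artifact_arn=model_data_artifact.artifact_arn,
            sagemaker_session=sagemaker_session,
        )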