Merged

Commits (29)
92451cb  Initial commit for tensorflow_serving support of MLflow (jiapinw, Apr 18, 2024)
8a723c9  Merge branch 'aws:master' into aloy-integration-ga-deployment-dev (jiapinw, Apr 23, 2024)
8ba826b  Merge branch 'aws:master' into aloy-integration-ga-deployment-dev (jiapinw, Apr 23, 2024)
aa7b0c6  Add integ tests for mlflow tf_serving (jiapinw, Apr 26, 2024)
a549adb  fix style issues (jiapinw, Apr 26, 2024)
6279dc4  Merge branch 'aws:master' into aloy-integration-ga-deployment-dev (jiapinw, May 2, 2024)
e198a79  remove unused attributes from tf builder (jiapinw, May 2, 2024)
2f565d7  Add deep ping for tf_serving local mode (jiapinw, May 2, 2024)
08096ff  Merge pull request #4 from jiapinw/aloy-integration-ga-deployment-dev (jiapinw, May 3, 2024)
8408e7a  Initial commit for lineage impl (jiapinw, Apr 27, 2024)
f0440f1  Merge branch 'aws:master' into aloy-integration-ga-master (jiapinw, May 6, 2024)
7961a74  Merge branch 'aws:master' into aloy-integration-ga-lineage-dev (jiapinw, May 6, 2024)
c75d7d3  Initial commit for tensorflow_serving support of MLflow (jiapinw, Apr 18, 2024)
590b36e  Add integ tests for mlflow tf_serving (jiapinw, Apr 26, 2024)
c8262b7  fix style issues (jiapinw, Apr 26, 2024)
47fa352  remove unused attributes from tf builder (jiapinw, May 2, 2024)
2739c3e  Add deep ping for tf_serving local mode (jiapinw, May 2, 2024)
f9adf44  Add integ tests and uts (jiapinw, May 7, 2024)
f2d9d36  fix local mode for tf_serving (jiapinw, May 7, 2024)
7cac6bc  Allow lineage tracking only in sagemaker endpoint mode (jiapinw, May 7, 2024)
345e17a  Merge branch 'aloy-integration-ga-master' into aloy-integration-ga-li… (jiapinw, May 7, 2024)
1b99418  fix regex pattern (jiapinw, May 8, 2024)
e45ad0f  Merge pull request #5 from jiapinw/aloy-integration-ga-lineage-dev (jiapinw, May 8, 2024)
480e13e  Merge branch 'aws:master' into aloy-integration-ga-master (jiapinw, May 8, 2024)
e7c9e59  fix style issues (jiapinw, May 8, 2024)
65ad677  fix regex pattern and hard coded py version in ut (jiapinw, May 8, 2024)
a276acc  fix missing session (jiapinw, May 8, 2024)
2a86534  Resolve pr comments and fix regex for mlflow registry and ids (jiapinw, May 9, 2024)
d03df30  Merge branch 'master' into aloy-integration-ga-master (jiapinw, May 9, 2024)
1 change: 1 addition & 0 deletions requirements/extras/test_requirements.txt
@@ -36,3 +36,4 @@ onnx>=1.15.0
nbformat>=5.9,<6
accelerate>=0.24.1,<=0.27.0
schema==0.7.5
tensorflow>=2.1,<=2.16
27 changes: 23 additions & 4 deletions src/sagemaker/serve/builder/model_builder.py
@@ -29,6 +29,7 @@
from sagemaker.serializers import NumpySerializer, TorchTensorSerializer
from sagemaker.deserializers import JSONDeserializer, TorchTensorDeserializer
from sagemaker.serve.builder.schema_builder import SchemaBuilder
from sagemaker.serve.builder.tf_serving_builder import TensorflowServing
from sagemaker.serve.mode.function_pointers import Mode
from sagemaker.serve.mode.sagemaker_endpoint_mode import SageMakerEndpointMode
from sagemaker.serve.mode.local_container_mode import LocalContainerMode
@@ -59,6 +60,7 @@
from sagemaker.serve.spec.inference_spec import InferenceSpec
from sagemaker.serve.utils import task
from sagemaker.serve.utils.exceptions import TaskNotFoundException
from sagemaker.serve.utils.lineage_utils import _maintain_lineage_tracking_for_mlflow_model
from sagemaker.serve.utils.predictors import _get_local_mode_predictor
from sagemaker.serve.utils.hardware_detector import (
_get_gpu_info,
@@ -89,12 +91,13 @@
ModelServer.TORCHSERVE,
ModelServer.TRITON,
ModelServer.DJL_SERVING,
ModelServer.TENSORFLOW_SERVING,
}


- # pylint: disable=attribute-defined-outside-init, disable=E1101
+ # pylint: disable=attribute-defined-outside-init, disable=E1101, disable=R0901
@dataclass
- class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers):
+ class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers, TensorflowServing):
"""Class that builds a deployable model.

Args:
@@ -493,6 +496,12 @@ def _model_builder_register_wrapper(self, *args, **kwargs):
self.pysdk_model.model_package_arn = new_model_package.model_package_arn
new_model_package.deploy = self._model_builder_deploy_model_package_wrapper
self.model_package = new_model_package
if getattr(self, "_is_mlflow_model", False) and self.mode == Mode.SAGEMAKER_ENDPOINT:
[Review comment] @makungaj1 (Contributor), May 8, 2024: Where is the _is_mlflow_model method defined?

[Author reply] @jiapinw: self._is_mlflow_model = self._check_if_input_is_mlflow_model()

_maintain_lineage_tracking_for_mlflow_model(
mlflow_model_path=self.model_metadata[MLFLOW_MODEL_PATH],
s3_upload_path=self.s3_upload_path,
sagemaker_session=self.sagemaker_session,
)
return new_model_package

def _model_builder_deploy_model_package_wrapper(self, *args, **kwargs):
@@ -551,12 +560,19 @@ def _model_builder_deploy_wrapper(

if "endpoint_logging" not in kwargs:
kwargs["endpoint_logging"] = True
- return self._original_deploy(
+ predictor = self._original_deploy(
*args,
instance_type=instance_type,
initial_instance_count=initial_instance_count,
**kwargs,
)
if getattr(self, "_is_mlflow_model", False) and self.mode == Mode.SAGEMAKER_ENDPOINT:
_maintain_lineage_tracking_for_mlflow_model(
mlflow_model_path=self.model_metadata[MLFLOW_MODEL_PATH],
s3_upload_path=self.s3_upload_path,
sagemaker_session=self.sagemaker_session,
)
return predictor

def _overwrite_mode_in_deploy(self, overwrite_mode: str):
"""Mode overwritten by customer during model.deploy()"""
@@ -728,7 +744,7 @@ def build( # pylint: disable=R0911
" for production at this moment."
)
self._initialize_for_mlflow()
- _validate_input_for_mlflow(self.model_server)
+ _validate_input_for_mlflow(self.model_server, self.env_vars.get("MLFLOW_MODEL_FLAVOR"))

if isinstance(self.model, str):
model_task = None
@@ -767,6 +783,9 @@
if self.model_server == ModelServer.TRITON:
return self._build_for_triton()

if self.model_server == ModelServer.TENSORFLOW_SERVING:
return self._build_for_tensorflow_serving()

raise ValueError("%s model server is not supported" % self.model_server)

def save(
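Taken together, the deploy-wrapper changes mean deploy() still returns the predictor, but for an MLflow model deployed in SAGEMAKER_ENDPOINT mode it first records lineage between the MLflow source path and the uploaded S3 artifacts. A minimal usage sketch under the new behavior (the role ARN, bucket path, sample data, and instance size below are illustrative assumptions, not from this PR):

import numpy as np
from sagemaker.serve.builder.model_builder import ModelBuilder
from sagemaker.serve.builder.schema_builder import SchemaBuilder
from sagemaker.serve.mode.function_pointers import Mode

# Placeholder sample data, used only to derive serializers/deserializers.
sample_input = np.random.rand(1, 28, 28).astype("float32")
sample_output = np.random.rand(1, 10).astype("float32")

builder = ModelBuilder(
    mode=Mode.SAGEMAKER_ENDPOINT,
    schema_builder=SchemaBuilder(sample_input=sample_input, sample_output=sample_output),
    role_arn="arn:aws:iam::111122223333:role/SageMakerRole",  # hypothetical role
    # MLFLOW_MODEL_PATH marks the input as an MLflow model.
    model_metadata={"MLFLOW_MODEL_PATH": "s3://amzn-example-bucket/mlflow-model/"},
)

model = builder.build()  # a tensorflow/keras flavor routes to _build_for_tensorflow_serving()
predictor = model.deploy(instance_type="ml.c6i.xlarge", initial_instance_count=1)
# Because mode == Mode.SAGEMAKER_ENDPOINT and _is_mlflow_model is set, the wrapper
# also calls _maintain_lineage_tracking_for_mlflow_model(...) before returning.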
129 changes: 129 additions & 0 deletions src/sagemaker/serve/builder/tf_serving_builder.py
@@ -0,0 +1,129 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Holds mixin logic to support deployment of Model ID"""
from __future__ import absolute_import
import logging
import os
from pathlib import Path
from abc import ABC, abstractmethod

from sagemaker import Session
from sagemaker.serve.detector.pickler import save_pkl
from sagemaker.serve.model_server.tensorflow_serving.prepare import prepare_for_tf_serving
from sagemaker.tensorflow import TensorFlowModel, TensorFlowPredictor

logger = logging.getLogger(__name__)

_TF_SERVING_MODEL_BUILDER_ENTRY_POINT = "inference.py"
_CODE_FOLDER = "code"


# pylint: disable=attribute-defined-outside-init, disable=E1101
class TensorflowServing(ABC):
"""TensorflowServing build logic for ModelBuilder()"""

def __init__(self):
self.model = None
self.serve_settings = None
self.sagemaker_session = None
self.model_path = None
self.dependencies = None
self.modes = None
self.mode = None
self.model_server = None
self.image_uri = None
self._is_custom_image_uri = False
self.image_config = None
self.vpc_config = None
self._original_deploy = None
self.secret_key = None
self.engine = None
self.pysdk_model = None
self.schema_builder = None
self.env_vars = None

@abstractmethod
def _prepare_for_mode(self):
"""Prepare model artifacts based on mode."""

@abstractmethod
def _get_client_translators(self):
"""Set up client marshaller based on schema builder."""

def _save_schema_builder(self):
"""Save schema builder for tensorflow serving."""
if not os.path.exists(self.model_path):
os.makedirs(self.model_path)

code_path = Path(self.model_path).joinpath("code")
save_pkl(code_path, self.schema_builder)

def _get_tensorflow_predictor(
self, endpoint_name: str, sagemaker_session: Session
) -> TensorFlowPredictor:
"""Creates a TensorFlowPredictor object"""
serializer, deserializer = self._get_client_translators()

return TensorFlowPredictor(
endpoint_name=endpoint_name,
sagemaker_session=sagemaker_session,
serializer=serializer,
deserializer=deserializer,
)

def _validate_for_tensorflow_serving(self):
"""Validate for tensorflow serving"""
if not getattr(self, "_is_mlflow_model", False):
raise ValueError("Tensorflow Serving is currently only supported for mlflow models.")

def _create_tensorflow_model(self):
"""Creates a TensorFlow model object"""
self.pysdk_model = TensorFlowModel(
image_uri=self.image_uri,
image_config=self.image_config,
vpc_config=self.vpc_config,
model_data=self.s3_upload_path,
role=self.serve_settings.role_arn,
env=self.env_vars,
sagemaker_session=self.sagemaker_session,
predictor_cls=self._get_tensorflow_predictor,
)

self.pysdk_model.mode = self.mode
self.pysdk_model.modes = self.modes
self.pysdk_model.serve_settings = self.serve_settings

self._original_deploy = self.pysdk_model.deploy
self.pysdk_model.deploy = self._model_builder_deploy_wrapper
self._original_register = self.pysdk_model.register
self.pysdk_model.register = self._model_builder_register_wrapper
self.model_package = None
return self.pysdk_model

def _build_for_tensorflow_serving(self):
"""Build the model for Tensorflow Serving"""
self._validate_for_tensorflow_serving()
self._save_schema_builder()

if not self.image_uri:
raise ValueError("image_uri is not set for tensorflow serving")

self.secret_key = prepare_for_tf_serving(
model_path=self.model_path,
shared_libs=self.shared_libs,
dependencies=self.dependencies,
)

self._prepare_for_mode()

return self._create_tensorflow_model()
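TensorflowServing is written as an abstract mixin: it declares _prepare_for_mode() and _get_client_translators() but never implements them, since the concrete ModelBuilder (which now inherits the mixin) is expected to supply both at runtime. A stripped-down sketch of that pattern, with illustrative class names rather than the SDK's:

from abc import ABC, abstractmethod

class ServingMixin(ABC):
    """Build logic that defers environment-specific steps to the host class."""

    @abstractmethod
    def _prepare_for_mode(self):
        """Provided by the concrete builder, not by the mixin."""

    def build(self):
        self._prepare_for_mode()  # resolves on the subclass via the MRO
        return "model"

class Builder(ServingMixin):
    def _prepare_for_mode(self):
        print("uploading artifacts / preparing local container")

Builder().build()  # prints the preparation step, then returns "model"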
17 changes: 16 additions & 1 deletion src/sagemaker/serve/mode/local_container_mode.py
@@ -11,6 +11,7 @@
import docker

from sagemaker.base_predictor import PredictorBase
from sagemaker.serve.model_server.tensorflow_serving.server import LocalTensorflowServing
from sagemaker.serve.spec.inference_spec import InferenceSpec
from sagemaker.serve.builder.schema_builder import SchemaBuilder
from sagemaker.serve.utils.logging_agent import pull_logs
@@ -34,7 +35,12 @@


class LocalContainerMode(
-     LocalTorchServe, LocalDJLServing, LocalTritonServer, LocalTgiServing, LocalMultiModelServer
+     LocalTorchServe,
+     LocalDJLServing,
+     LocalTritonServer,
+     LocalTgiServing,
+     LocalMultiModelServer,
+     LocalTensorflowServing,
):
"""A class that holds methods to deploy model to a container in local environment"""

@@ -141,6 +147,15 @@ def create_server(
env_vars=env_vars if env_vars else self.env_vars,
)
self._ping_container = self._multi_model_server_deep_ping
elif self.model_server == ModelServer.TENSORFLOW_SERVING:
self._start_tensorflow_serving(
client=self.client,
image=image,
model_path=model_path if model_path else self.model_path,
secret_key=secret_key,
env_vars=env_vars if env_vars else self.env_vars,
)
self._ping_container = self._tensorflow_serving_deep_ping

# allow some time for container to be ready
time.sleep(10)
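The new TENSORFLOW_SERVING branch wires up _tensorflow_serving_deep_ping as the readiness check. A deep ping sends a real sample invocation to the local container instead of relying only on the bare /ping route, so "ready" also implies the model actually loads and answers. A rough sketch of the idea, assuming the standard SageMaker serving-container contract (port 8080, /invocations); the function name and payload here are hypothetical, and the PR's real implementation lives in the serve utils/predictors modules:

import requests

def tensorflow_serving_deep_ping(sample_payload: bytes) -> bool:
    """Return True once the local container answers a real inference request."""
    try:
        response = requests.post(
            "http://localhost:8080/invocations",  # standard SageMaker serving route
            data=sample_payload,
            headers={"Content-Type": "application/json"},
            timeout=10,
        )
        return response.status_code == 200
    except requests.exceptions.ConnectionError:
        return False  # container not up yet; caller retries after a delay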
11 changes: 11 additions & 0 deletions src/sagemaker/serve/mode/sagemaker_endpoint_mode.py
@@ -6,6 +6,7 @@
import logging
from typing import Type

from sagemaker.serve.model_server.tensorflow_serving.server import SageMakerTensorflowServing
from sagemaker.session import Session
from sagemaker.serve.utils.types import ModelServer
from sagemaker.serve.spec.inference_spec import InferenceSpec
@@ -24,6 +25,7 @@ class SageMakerEndpointMode(
SageMakerDjlServing,
SageMakerTgiServing,
SageMakerMultiModelServer,
SageMakerTensorflowServing,
):
"""Holds the required method to deploy a model to a SageMaker Endpoint"""

@@ -107,4 +109,13 @@ def prepare(
image=image,
)

if self.model_server == ModelServer.TENSORFLOW_SERVING:
return self._upload_tensorflow_serving_artifacts(
model_path=model_path,
sagemaker_session=sagemaker_session,
secret_key=secret_key,
s3_model_data_url=s3_model_data_url,
image=image,
)

raise ValueError("%s model server is not supported" % self.model_server)
20 changes: 15 additions & 5 deletions src/sagemaker/serve/model_format/mlflow/constants.py
@@ -19,6 +19,12 @@
"py39": "1.13.1",
"py310": "2.2.0",
}
MODEL_PACAKGE_ARN_REGEX = (
r"^arn:aws:sagemaker:[a-z0-9\-]+:[0-9]{12}:model-package\/[" r"a-zA-Z0-9\-_\/\.]+$"
)
MLFLOW_RUN_ID_REGEX = r"^runs:/[a-zA-Z0-9]+(/[a-zA-Z0-9]+)*$"
MLFLOW_REGISTRY_PATH_REGEX = r"^models:/[a-zA-Z0-9\-_\.]+(/[0-9]+)*$"
S3_PATH_REGEX = r"^s3:\/\/[a-zA-Z0-9\-_\.]+\/[a-zA-Z0-9\-_\/\.]*$"
MLFLOW_MODEL_PATH = "MLFLOW_MODEL_PATH"
MLFLOW_METADATA_FILE = "MLmodel"
MLFLOW_PIP_DEPENDENCY_FILE = "requirements.txt"
@@ -34,8 +40,12 @@
"spark": "pyspark",
"onnx": "onnxruntime",
}
- FLAVORS_WITH_FRAMEWORK_SPECIFIC_DLC_SUPPORT = [  # will extend to keras and tf
-     "sklearn",
-     "pytorch",
-     "xgboost",
- ]
+ TENSORFLOW_SAVED_MODEL_NAME = "saved_model.pb"
+ FLAVORS_WITH_FRAMEWORK_SPECIFIC_DLC_SUPPORT = {
+     "sklearn": "sklearn",
+     "pytorch": "pytorch",
+     "xgboost": "xgboost",
+     "tensorflow": "tensorflow",
+     "keras": "tensorflow",
+ }
+ FLAVORS_DEFAULT_WITH_TF_SERVING = ["keras", "tensorflow"]
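A quick sanity check of the new path-classification regexes against made-up identifiers (the run ID, model name, and bucket below are invented for illustration):

import re

MLFLOW_RUN_ID_REGEX = r"^runs:/[a-zA-Z0-9]+(/[a-zA-Z0-9]+)*$"
MLFLOW_REGISTRY_PATH_REGEX = r"^models:/[a-zA-Z0-9\-_\.]+(/[0-9]+)*$"
S3_PATH_REGEX = r"^s3:\/\/[a-zA-Z0-9\-_\.]+\/[a-zA-Z0-9\-_\/\.]*$"

assert re.match(MLFLOW_RUN_ID_REGEX, "runs:/0e1f3b2c9a8d/model")
assert re.match(MLFLOW_REGISTRY_PATH_REGEX, "models:/my-model/3")
assert re.match(S3_PATH_REGEX, "s3://my-bucket/mlflow/model/")

# A registry path must not be mistaken for a run ID (different prefix).
assert not re.match(MLFLOW_RUN_ID_REGEX, "models:/my-model/3")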