Add tensorflow_serving support for mlflow models and enable lineage tracking for mlflow models #4662
Test requirements (the `tensorflow>=2.1` pin discussed in the review thread at the end):

```diff
@@ -36,3 +36,4 @@ onnx>=1.15.0
 nbformat>=5.9,<6
 accelerate>=0.24.1,<=0.27.0
 schema==0.7.5
+tensorflow>=2.1
```
sagemaker/serve/builder/model_builder.py:

```diff
@@ -29,6 +29,7 @@
 from sagemaker.serializers import NumpySerializer, TorchTensorSerializer
 from sagemaker.deserializers import JSONDeserializer, TorchTensorDeserializer
 from sagemaker.serve.builder.schema_builder import SchemaBuilder
+from sagemaker.serve.builder.tf_serving_builder import TensorflowServing
 from sagemaker.serve.mode.function_pointers import Mode
 from sagemaker.serve.mode.sagemaker_endpoint_mode import SageMakerEndpointMode
 from sagemaker.serve.mode.local_container_mode import LocalContainerMode
@@ -59,6 +60,7 @@
 from sagemaker.serve.spec.inference_spec import InferenceSpec
 from sagemaker.serve.utils import task
 from sagemaker.serve.utils.exceptions import TaskNotFoundException
+from sagemaker.serve.utils.lineage_utils import _maintain_lineage_tracking_for_mlflow_model
 from sagemaker.serve.utils.predictors import _get_local_mode_predictor
 from sagemaker.serve.utils.hardware_detector import (
     _get_gpu_info,
@@ -89,12 +91,13 @@
     ModelServer.TORCHSERVE,
     ModelServer.TRITON,
     ModelServer.DJL_SERVING,
+    ModelServer.TENSORFLOW_SERVING,
 }


-# pylint: disable=attribute-defined-outside-init, disable=E1101
+# pylint: disable=attribute-defined-outside-init, disable=E1101, disable=R0901
 @dataclass
-class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers):
+class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers, TensorflowServing):
     """Class that builds a deployable model.

     Args:
```
```diff
@@ -493,6 +496,12 @@ def _model_builder_register_wrapper(self, *args, **kwargs):
         self.pysdk_model.model_package_arn = new_model_package.model_package_arn
         new_model_package.deploy = self._model_builder_deploy_model_package_wrapper
         self.model_package = new_model_package
+        if getattr(self, "_is_mlflow_model", False) and self.mode == Mode.SAGEMAKER_ENDPOINT:
+            _maintain_lineage_tracking_for_mlflow_model(
+                mlflow_model_path=self.model_metadata[MLFLOW_MODEL_PATH],
+                s3_upload_path=self.s3_upload_path,
+                sagemaker_session=self.sagemaker_session,
+            )
         return new_model_package

     def _model_builder_deploy_model_package_wrapper(self, *args, **kwargs):
```
```diff
@@ -551,12 +560,19 @@ def _model_builder_deploy_wrapper(

         if "endpoint_logging" not in kwargs:
             kwargs["endpoint_logging"] = True
-        return self._original_deploy(
+        predictor = self._original_deploy(
             *args,
             instance_type=instance_type,
             initial_instance_count=initial_instance_count,
             **kwargs,
         )
+        if getattr(self, "_is_mlflow_model", False) and self.mode == Mode.SAGEMAKER_ENDPOINT:
+            _maintain_lineage_tracking_for_mlflow_model(
+                mlflow_model_path=self.model_metadata[MLFLOW_MODEL_PATH],
+                s3_upload_path=self.s3_upload_path,
+                sagemaker_session=self.sagemaker_session,
+            )
+        return predictor

     def _overwrite_mode_in_deploy(self, overwrite_mode: str):
         """Mode overwritten by customer during model.deploy()"""
```
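Both the register and deploy wrappers above now call `_maintain_lineage_tracking_for_mlflow_model` after a successful register/deploy of an MLflow model in `Mode.SAGEMAKER_ENDPOINT`. The helper itself lives in `sagemaker/serve/utils/lineage_utils.py` and is not part of this diff; the following is only a hedged sketch of the kind of lineage wiring it could perform with the SDK's existing lineage APIs (the artifact type names and association type below are assumptions, not the PR's code):

```python
# Hypothetical sketch -- not the PR's lineage_utils implementation.
from sagemaker.lineage.artifact import Artifact
from sagemaker.lineage.association import Association

def sketch_lineage_tracking(mlflow_model_path, s3_upload_path, sagemaker_session):
    """Link the source MLflow model path to the translated S3 model artifact."""
    mlflow_artifact = Artifact.create(
        source_uri=mlflow_model_path,  # local path, S3 URI, or registry URI
        artifact_type="MLflowModel",   # assumed artifact type name
        sagemaker_session=sagemaker_session,
    )
    model_data_artifact = Artifact.create(
        source_uri=s3_upload_path,     # where ModelBuilder uploaded model.tar.gz
        artifact_type="Model",
        sagemaker_session=sagemaker_session,
    )
    # Record that the MLflow model contributed to the deployable artifact.
    Association.create(
        source_arn=mlflow_artifact.artifact_arn,
        destination_arn=model_data_artifact.artifact_arn,
        association_type="ContributedTo",
        sagemaker_session=sagemaker_session,
    )
```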
```diff
@@ -728,7 +744,7 @@ def build(  # pylint: disable=R0911
                 " for production at this moment."
             )
             self._initialize_for_mlflow()
-            _validate_input_for_mlflow(self.model_server)
+            _validate_input_for_mlflow(self.model_server, self.env_vars.get("MLFLOW_MODEL_FLAVOR"))

         if isinstance(self.model, str):
             model_task = None
```
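`_validate_input_for_mlflow` now also receives the MLflow model flavor recorded in the env vars. Its body is not shown in this diff; a minimal sketch of the kind of server/flavor compatibility check the extra argument enables, with the supported-flavor set and the `ModelServer` import path both assumed:

```python
# Hypothetical sketch -- the real _validate_input_for_mlflow may differ.
from sagemaker.serve.utils.types import ModelServer  # assumed import path

def validate_input_for_mlflow_sketch(model_server, mlflow_model_flavor):
    """Reject server/flavor pairs the builder cannot serve."""
    if model_server == ModelServer.TENSORFLOW_SERVING and mlflow_model_flavor not in (
        "tensorflow",
        "keras",
    ):
        raise ValueError(
            f"Tensorflow Serving cannot serve MLflow flavor '{mlflow_model_flavor}'."
        )
```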
```diff
@@ -767,6 +783,9 @@ def build(  # pylint: disable=R0911
         if self.model_server == ModelServer.TRITON:
             return self._build_for_triton()

+        if self.model_server == ModelServer.TENSORFLOW_SERVING:
+            return self._build_for_tensorflow_serving()
+
         raise ValueError("%s model server is not supported" % self.model_server)

     def save(
```
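With the dispatch above in place, an MLflow model can be built for Tensorflow Serving end to end. A hedged usage sketch follows; the bucket, model path, role ARN, sample shapes, and instance type are illustrative, and the `ModelServer` import path is an assumption:

```python
import numpy as np

from sagemaker.serve.builder.model_builder import ModelBuilder
from sagemaker.serve.builder.schema_builder import SchemaBuilder
from sagemaker.serve.utils.types import ModelServer  # assumed import path

# Sample input/output let SchemaBuilder derive (de)serializers for the endpoint.
schema = SchemaBuilder(
    sample_input=np.zeros((1, 28, 28, 1), dtype=np.float32),
    sample_output=np.zeros((1, 10), dtype=np.float32),
)

builder = ModelBuilder(
    model_server=ModelServer.TENSORFLOW_SERVING,
    schema_builder=schema,
    # MLFLOW_MODEL_PATH points at the logged MLflow model (illustrative URI).
    model_metadata={"MLFLOW_MODEL_PATH": "s3://my-bucket/my-mlflow-model/"},
    role_arn="arn:aws:iam::123456789012:role/SageMakerRole",  # illustrative
)

model = builder.build()      # routes to _build_for_tensorflow_serving()
predictor = model.deploy(    # wrapped deploy; also fires the lineage hook
    instance_type="ml.c5.xlarge",
    initial_instance_count=1,
)
```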
sagemaker/serve/builder/tf_serving_builder.py (new file, +134 lines):

```python
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Holds mixin logic to support deployment of Model ID"""
from __future__ import absolute_import
import logging
import os
from pathlib import Path
from abc import ABC, abstractmethod

from sagemaker import Session
from sagemaker.serve.detector.pickler import save_pkl
from sagemaker.serve.model_server.tensorflow_serving.prepare import prepare_for_tf_serving
from sagemaker.tensorflow import TensorFlowModel, TensorFlowPredictor

logger = logging.getLogger(__name__)

_TF_SERVING_MODEL_BUILDER_ENTRY_POINT = "inference.py"
_CODE_FOLDER = "code"


# pylint: disable=attribute-defined-outside-init, disable=E1101
class TensorflowServing(ABC):
    """TensorflowServing build logic for ModelBuilder()"""

    def __init__(self):
        self.model = None
        self.serve_settings = None
        self.sagemaker_session = None
        self.model_path = None
        self.dependencies = None
        self.modes = None
        self.mode = None
        self.model_server = None
        self.image_uri = None
        self._is_custom_image_uri = False
        self.image_config = None
        self.vpc_config = None
        self._original_deploy = None
        self.secret_key = None
        self.engine = None
        self.pysdk_model = None
        self.schema_builder = None
        self.env_vars = None

    @abstractmethod
    def _prepare_for_mode(self):
        """Placeholder docstring"""

    @abstractmethod
    def _get_client_translators(self):
        """Placeholder docstring"""

    def _save_schema_builder(self):
        """Save schema builder for tensorflow serving."""
        if not os.path.exists(self.model_path):
            os.makedirs(self.model_path)

        code_path = Path(self.model_path).joinpath("code")
        save_pkl(code_path, self.schema_builder)

    def _get_tensorflow_predictor(
        self, endpoint_name: str, sagemaker_session: Session
    ) -> TensorFlowPredictor:
        """Placeholder docstring"""
        serializer, deserializer = self._get_client_translators()

        return TensorFlowPredictor(
            endpoint_name=endpoint_name,
            sagemaker_session=sagemaker_session,
            serializer=serializer,
            deserializer=deserializer,
        )

    def _validate_for_tensorflow_serving(self):
        """Validate for tensorflow serving"""
        if not getattr(self, "_is_mlflow_model", False):
            raise ValueError("Tensorflow Serving is currently only supported for mlflow models.")

    def _create_tensorflow_model(self):
        """Placeholder docstring"""
        # TODO: we should create model as per the framework
        self.pysdk_model = TensorFlowModel(
            image_uri=self.image_uri,
            image_config=self.image_config,
            vpc_config=self.vpc_config,
            model_data=self.s3_upload_path,
            role=self.serve_settings.role_arn,
            env=self.env_vars,
            sagemaker_session=self.sagemaker_session,
            predictor_cls=self._get_tensorflow_predictor,
        )

        # store the modes in the model so that we may
        # reference the configurations for local deploy() & predict()
        self.pysdk_model.mode = self.mode
        self.pysdk_model.modes = self.modes
        self.pysdk_model.serve_settings = self.serve_settings

        # dynamically generate a method to direct model.deploy() logic based on mode
        # unique method to models created via ModelBuilder()
        self._original_deploy = self.pysdk_model.deploy
        self.pysdk_model.deploy = self._model_builder_deploy_wrapper
        self._original_register = self.pysdk_model.register
        self.pysdk_model.register = self._model_builder_register_wrapper
        self.model_package = None
        return self.pysdk_model

    def _build_for_tensorflow_serving(self):
        """Build the model for Tensorflow Serving"""
        self._validate_for_tensorflow_serving()
        self._save_schema_builder()

        if not self.image_uri:
            raise ValueError("image_uri is not set for tensorflow serving")

        self.secret_key = prepare_for_tf_serving(
            model_path=self.model_path,
            shared_libs=self.shared_libs,
            dependencies=self.dependencies,
        )

        self._prepare_for_mode()

        return self._create_tensorflow_model()
```
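`prepare_for_tf_serving` is referenced above but defined in `sagemaker/serve/model_server/tensorflow_serving/prepare.py`, outside the lines shown here. Below is only a hedged sketch of the staging step it appears responsible for, by analogy with the other model servers in `sagemaker.serve`; every detail (the file layout, the requirements format, the integrity key) is an assumption, not the PR's code:

```python
# Hypothetical sketch only -- not the PR's prepare_for_tf_serving.
import secrets
import shutil
from pathlib import Path
from typing import Dict, List

def prepare_for_tf_serving_sketch(
    model_path: str, shared_libs: List[str], dependencies: Dict[str, str]
) -> str:
    """Stage the code/ directory for Tensorflow Serving and return a secret key,
    assumed to be used later to verify the pickled schema builder's integrity."""
    code_dir = Path(model_path) / "code"
    code_dir.mkdir(parents=True, exist_ok=True)

    for lib in shared_libs:  # copy native shared libraries next to the code
        shutil.copy2(lib, code_dir)

    # Record captured dependencies so the container can install them (assumed format).
    with open(code_dir / "requirements.txt", "w") as f:
        for name, spec in dependencies.items():
            f.write(f"{name}{spec}\n")

    return secrets.token_hex(32)  # assumed integrity key, stored as self.secret_key
```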
Contributor (on the `tensorflow>=2.1` requirements change): Does this need to be added to the extras requirements? I think a lower bound is optional here, but an upper bound is required for versions.

Author: No, this is only needed for integ tests, since we need to convert to and from tensors.