diff --git a/src/sagemaker/serve/marshalling/triton_translator.py b/src/sagemaker/serve/marshalling/triton_translator.py index 4c3973814c..034f536fa6 100644 --- a/src/sagemaker/serve/marshalling/triton_translator.py +++ b/src/sagemaker/serve/marshalling/triton_translator.py @@ -16,6 +16,8 @@ def __init__(self) -> None: import torch self.convert_from_numpy = torch.from_numpy # pylint: disable=E1101 + self.CONTENT_TYPE = "tensor/pt" + self.ACCEPT = "tensor/pt" def serialize(self, data, content_type: str = "tensor/pt"): """Translate torch.Tensor to numpy ndarray""" @@ -45,6 +47,8 @@ def __init__(self) -> None: import tensorflow as tf self.convert_to_tensor = tf.convert_to_tensor + self.CONTENT_TYPE = "tensor/tf" + self.ACCEPT = "tensor/tf" def serialize(self, data, content_type: str = "tensor/tf"): """Translate tf.Tensor to numpy ndarray""" @@ -70,6 +74,10 @@ def _deserializer(self): class NumpyTranslator: """A dummy class to make sure the translator interface is aligned""" + def __init__(self) -> None: + self.CONTENT_TYPE = "application/x-npy" + self.ACCEPT = "application/x-npy" + def serialize(self, data, content_type: str = "application/x-npy"): """Placeholder docstring""" return data @@ -86,6 +94,10 @@ def _deserializer(self): class ListTranslator: """Translate python list from and to numpy.ndarray""" + def __init__(self) -> None: + self.CONTENT_TYPE = "application/list" + self.ACCEPT = "application/list" + def serialize(self, data, content_type: str = "application/list"): """Placeholder docstring""" try: diff --git a/src/sagemaker/serve/model_server/triton/triton_builder.py b/src/sagemaker/serve/model_server/triton/triton_builder.py index 931f67b02b..2fc895186a 100644 --- a/src/sagemaker/serve/model_server/triton/triton_builder.py +++ b/src/sagemaker/serve/model_server/triton/triton_builder.py @@ -430,6 +430,9 @@ def _create_triton_model(self) -> Type[Model]: # unique method to models created via ModelBuilder() self._original_deploy = self.pysdk_model.deploy self.pysdk_model.deploy = self._model_builder_deploy_wrapper + self._original_register = self.pysdk_model.register + self.pysdk_model.register = self._model_builder_register_wrapper + self.model_package = None return self.pysdk_model def _get_triton_predictor(self, endpoint_name: str, sagemaker_session: Session) -> Predictor: diff --git a/tests/unit/sagemaker/serve/model_server/triton/test_triton_builder.py b/tests/unit/sagemaker/serve/model_server/triton/test_triton_builder.py index d469c7f392..edb40fbb69 100644 --- a/tests/unit/sagemaker/serve/model_server/triton/test_triton_builder.py +++ b/tests/unit/sagemaker/serve/model_server/triton/test_triton_builder.py @@ -33,6 +33,7 @@ MOCK_SESSION = Mock() MOCK_MODES = Mock() MOCK_DEPLOY_WRAPPER = Mock() +MOCK_RESIGTER_WRAPPER = Mock() class pytorch: @@ -56,6 +57,7 @@ def prepare_triton_builder_for_model(self, triton_builder: Triton) -> Triton: triton_builder.sagemaker_session = MOCK_SESSION triton_builder.modes = MOCK_MODES triton_builder._model_builder_deploy_wrapper = MOCK_DEPLOY_WRAPPER + triton_builder._model_builder_register_wrapper = MOCK_RESIGTER_WRAPPER triton_builder.inference_spec = None mock_export = Mock()