|
21 | 21 | import numpy as np |
22 | 22 |
|
23 | 23 | from sagemaker.content_types import CONTENT_TYPE_JSON, CONTENT_TYPE_CSV, CONTENT_TYPE_NPY |
| 24 | +from sagemaker.deserializers import BaseDeserializer |
24 | 25 | from sagemaker.model_monitor import DataCaptureConfig |
| 26 | +from sagemaker.serializers import BaseSerializer |
25 | 27 | from sagemaker.session import production_variant, Session |
26 | 28 | from sagemaker.utils import name_from_base |
27 | 29 |
|
@@ -59,27 +61,28 @@ def __init__( |
59 | 61 | object, used for SageMaker interactions (default: None). If not |
60 | 62 | specified, one is created using the default AWS configuration |
61 | 63 | chain. |
62 | | - serializer (callable): Accepts a single argument, the input data, |
63 | | - and returns a sequence of bytes. It may provide a |
64 | | - ``content_type`` attribute that defines the endpoint request |
65 | | - content type. If not specified, a sequence of bytes is expected |
66 | | - for the data. |
67 | | - deserializer (callable): Accepts two arguments, the result data and |
68 | | - the response content type, and returns a sequence of bytes. It |
69 | | - may provide a ``content_type`` attribute that defines the |
70 | | - endpoint response's "Accept" content type. If not specified, a |
71 | | - sequence of bytes is expected for the data. |
| 64 | + serializer (sagemaker.serializers.BaseSerializer): A serializer |
| 65 | + object, used to encode data for an inference endpoint |
| 66 | + (default: None). |
| 67 | + deserializer (sagemaker.deserializers.BaseDeserializer): A |
| 68 | + deserializer object, used to decode data from an inference |
| 69 | + endpoint (default: None). |
72 | 70 | content_type (str): The invocation's "ContentType", overriding any |
73 | | - ``content_type`` from the serializer (default: None). |
| 71 | + ``CONTENT_TYPE`` from the serializer (default: None). |
74 | 72 | accept (str): The invocation's "Accept", overriding any accept from |
75 | 73 | the deserializer (default: None). |
76 | 74 | """ |
| 75 | + if serializer is not None and not isinstance(serializer, BaseSerializer): |
| 76 | + serializer = LegacySerializer(serializer) |
| 77 | + if deserializer is not None and not isinstance(deserializer, BaseDeserializer): |
| 78 | + deserializer = LegacyDeserializer(deserializer) |
| 79 | + |
77 | 80 | self.endpoint_name = endpoint_name |
78 | 81 | self.sagemaker_session = sagemaker_session or Session() |
79 | 82 | self.serializer = serializer |
80 | 83 | self.deserializer = deserializer |
81 | | - self.content_type = content_type or getattr(serializer, "content_type", None) |
82 | | - self.accept = accept or getattr(deserializer, "accept", None) |
| 84 | + self.content_type = content_type or getattr(serializer, "CONTENT_TYPE", None) |
| 85 | + self.accept = accept or getattr(deserializer, "ACCEPT", None) |
83 | 86 | self._endpoint_config_name = self._get_endpoint_config_name() |
84 | 87 | self._model_names = self._get_model_names() |
85 | 88 |
|
@@ -120,8 +123,10 @@ def _handle_response(self, response): |
120 | 123 | """ |
121 | 124 | response_body = response["Body"] |
122 | 125 | if self.deserializer is not None: |
| 126 | + if not isinstance(self.deserializer, BaseDeserializer): |
| 127 | + self.deserializer = LegacyDeserializer(self.deserializer) |
123 | 128 | # It's the deserializer's responsibility to close the stream |
124 | | - return self.deserializer(response_body, response["ContentType"]) |
| 129 | + return self.deserializer.deserialize(response_body, response["ContentType"]) |
125 | 130 | data = response_body.read() |
126 | 131 | response_body.close() |
127 | 132 | return data |
@@ -152,7 +157,9 @@ def _create_request_args(self, data, initial_args=None, target_model=None, targe |
152 | 157 | args["TargetVariant"] = target_variant |
153 | 158 |
|
154 | 159 | if self.serializer is not None: |
155 | | - data = self.serializer(data) |
| 160 | + if not isinstance(self.serializer, BaseSerializer): |
| 161 | + self.serializer = LegacySerializer(self.serializer) |
| 162 | + data = self.serializer.serialize(data) |
156 | 163 |
|
157 | 164 | args["Body"] = data |
158 | 165 | return args |
@@ -406,6 +413,88 @@ def _get_model_names(self): |
406 | 413 | return [d["ModelName"] for d in production_variants] |
407 | 414 |
|
408 | 415 |
|
| 416 | +class LegacySerializer(BaseSerializer): |
| 417 | + """Wrapper that makes legacy serializers forward compatibile.""" |
| 418 | + |
| 419 | + def __init__(self, serializer): |
| 420 | + """Initialize a ``LegacySerializer``. |
| 421 | +
|
| 422 | + Args: |
| 423 | + serializer (callable): A legacy serializer. |
| 424 | + """ |
| 425 | + self.serializer = serializer |
| 426 | + self.content_type = getattr(serializer, "content_type", None) |
| 427 | + |
| 428 | + def __call__(self, *args, **kwargs): |
| 429 | + """Wraps the call method of the legacy serializer. |
| 430 | +
|
| 431 | + Args: |
| 432 | + data (object): Data to be serialized. |
| 433 | +
|
| 434 | + Returns: |
| 435 | + object: Serialized data used for a request. |
| 436 | + """ |
| 437 | + return self.serializer(*args, **kwargs) |
| 438 | + |
| 439 | + def serialize(self, data): |
| 440 | + """Wraps the call method of the legacy serializer. |
| 441 | +
|
| 442 | + Args: |
| 443 | + data (object): Data to be serialized. |
| 444 | +
|
| 445 | + Returns: |
| 446 | + object: Serialized data used for a request. |
| 447 | + """ |
| 448 | + return self.serializer(data) |
| 449 | + |
| 450 | + @property |
| 451 | + def CONTENT_TYPE(self): |
| 452 | + """The MIME type of the data sent to the inference endpoint.""" |
| 453 | + return self.content_type |
| 454 | + |
| 455 | + |
| 456 | +class LegacyDeserializer(BaseDeserializer): |
| 457 | + """Wrapper that makes legacy deserializers forward compatibile.""" |
| 458 | + |
| 459 | + def __init__(self, deserializer): |
| 460 | + """Initialize a ``LegacyDeserializer``. |
| 461 | +
|
| 462 | + Args: |
| 463 | + deserializer (callable): A legacy deserializer. |
| 464 | + """ |
| 465 | + self.deserializer = deserializer |
| 466 | + self.accept = getattr(deserializer, "accept", None) |
| 467 | + |
| 468 | + def __call__(self, *args, **kwargs): |
| 469 | + """Wraps the call method of the legacy deserializer. |
| 470 | +
|
| 471 | + Args: |
| 472 | + data (object): Data to be deserialized. |
| 473 | + content_type (str): The MIME type of the data. |
| 474 | +
|
| 475 | + Returns: |
| 476 | + object: The data deserialized into an object. |
| 477 | + """ |
| 478 | + return self.deserializer(*args, **kwargs) |
| 479 | + |
| 480 | + def deserialize(self, data, content_type): |
| 481 | + """Wraps the call method of the legacy deserializer. |
| 482 | +
|
| 483 | + Args: |
| 484 | + data (object): Data to be deserialized. |
| 485 | + content_type (str): The MIME type of the data. |
| 486 | +
|
| 487 | + Returns: |
| 488 | + object: The data deserialized into an object. |
| 489 | + """ |
| 490 | + return self.deserializer(data, content_type) |
| 491 | + |
| 492 | + @property |
| 493 | + def ACCEPT(self): |
| 494 | + """The content type that is expected from the inference endpoint.""" |
| 495 | + return self.accept |
| 496 | + |
| 497 | + |
409 | 498 | class _CsvSerializer(object): |
410 | 499 | """Placeholder docstring""" |
411 | 500 |
|
|
0 commit comments