From 50a305555d124c6f61d942884b50c62944330dd7 Mon Sep 17 00:00:00 2001 From: tchaton Date: Tue, 30 Nov 2021 17:15:26 +0100 Subject: [PATCH 01/11] update --- flash/core/data/process.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/flash/core/data/process.py b/flash/core/data/process.py index 027f992017..54812714fe 100644 --- a/flash/core/data/process.py +++ b/flash/core/data/process.py @@ -13,16 +13,33 @@ # limitations under the License. import functools from abc import abstractmethod -from typing import Any, Mapping +from typing import Any, List, Mapping from warnings import warn from deprecate import deprecated import flash +from flash.core.data.io.input import Input from flash.core.data.io.output import Output from flash.core.data.properties import Properties +class ServeInput(Input): + def load_data(self, data: Any) -> List[Any]: + return [data] + + def deserialize(self, sample: Any) -> Any: # TODO: Output must be a tensor??? + raise NotImplementedError + + @property + @abstractmethod + def example_input(self) -> str: + raise NotImplementedError + + def __call__(self, sample: Any) -> Any: + return self.deserialize(sample) + + class Deserializer(Properties): """Deserializer.""" From da90250345cf70037f2b1112ed025d079ee10abb Mon Sep 17 00:00:00 2001 From: tchaton Date: Tue, 30 Nov 2021 22:28:59 +0100 Subject: [PATCH 02/11] update --- flash/audio/speech_recognition/data.py | 4 ++-- flash/core/data/batch.py | 2 +- flash/core/data/process.py | 20 +++++++------------- flash/image/data.py | 2 +- flash/image/segmentation/data.py | 2 +- flash/tabular/data.py | 2 +- flash/text/classification/data.py | 2 +- 7 files changed, 14 insertions(+), 20 deletions(-) diff --git a/flash/audio/speech_recognition/data.py b/flash/audio/speech_recognition/data.py index ce9ee5dd6e..0c50036e17 100644 --- a/flash/audio/speech_recognition/data.py +++ b/flash/audio/speech_recognition/data.py @@ -43,12 +43,12 @@ class SpeechRecognitionDeserializer(Deserializer): + @requires("audio") def __init__(self, sampling_rate: int = 16000): super().__init__() - self.sampling_rate = sampling_rate - def deserialize(self, sample: Any) -> Dict: + def serve_load_sample(self, sample: Any) -> Dict: encoded_with_padding = (sample + "===").encode("ascii") audio = base64.b64decode(encoded_with_padding) buffer = io.BytesIO(audio) diff --git a/flash/core/data/batch.py b/flash/core/data/batch.py index 2f77f2b56d..694fff7d73 100644 --- a/flash/core/data/batch.py +++ b/flash/core/data/batch.py @@ -46,7 +46,7 @@ def __init__( def forward(self, sample: str): - sample = self.deserializer(sample) + sample = self.deserializer.deserialize(sample) with self._current_stage_context: with self._pre_tensor_transform_context: diff --git a/flash/core/data/process.py b/flash/core/data/process.py index 54812714fe..afa1bf0c32 100644 --- a/flash/core/data/process.py +++ b/flash/core/data/process.py @@ -21,39 +21,33 @@ import flash from flash.core.data.io.input import Input from flash.core.data.io.output import Output -from flash.core.data.properties import Properties class ServeInput(Input): - def load_data(self, data: Any) -> List[Any]: - return [data] + def serve_load_data(self, data: Any) -> List[Any]: + raise NotImplementedError - def deserialize(self, sample: Any) -> Any: # TODO: Output must be a tensor??? + def serve_load_sample(self, sample: Any) -> List[Any]: raise NotImplementedError @property @abstractmethod - def example_input(self) -> str: + def serve_example_input(self) -> str: raise NotImplementedError - def __call__(self, sample: Any) -> Any: - return self.deserialize(sample) - -class Deserializer(Properties): +class Deserializer(ServeInput): """Deserializer.""" def deserialize(self, sample: Any) -> Any: # TODO: Output must be a tensor??? - raise NotImplementedError + sample = self.serve_load_data(sample) + return self.serve_load_sample(sample) @property @abstractmethod def example_input(self) -> str: raise NotImplementedError - def __call__(self, sample: Any) -> Any: - return self.deserialize(sample) - class DeserializerMapping(Deserializer): # TODO: This is essentially a duplicate of OutputMapping, should be abstracted away somewhere diff --git a/flash/image/data.py b/flash/image/data.py index 3d098e4c17..4329e1ec22 100644 --- a/flash/image/data.py +++ b/flash/image/data.py @@ -58,7 +58,7 @@ def image_loader(filepath: str): class ImageDeserializer(Deserializer): @requires("image") - def deserialize(self, data: str) -> Dict: + def serve_load_sample(self, data: str) -> Dict: encoded_with_padding = (data + "===").encode("ascii") img = base64.b64decode(encoded_with_padding) buffer = BytesIO(img) diff --git a/flash/image/segmentation/data.py b/flash/image/segmentation/data.py index 69eb67b783..a9a685fc87 100644 --- a/flash/image/segmentation/data.py +++ b/flash/image/segmentation/data.py @@ -211,7 +211,7 @@ def predict_load_sample(sample: Mapping[str, Any]) -> Mapping[str, Any]: class SemanticSegmentationDeserializer(ImageDeserializer): - def deserialize(self, data: str) -> torch.Tensor: + def serve_load_sample(self, data: str) -> torch.Tensor: result = super().deserialize(data) result[DataKeys.INPUT] = FT.to_tensor(result[DataKeys.INPUT]) result[DataKeys.METADATA] = {"size": result[DataKeys.INPUT].shape} diff --git a/flash/tabular/data.py b/flash/tabular/data.py index b0a43ce878..609d057148 100644 --- a/flash/tabular/data.py +++ b/flash/tabular/data.py @@ -186,7 +186,7 @@ def parameters(self) -> Dict[str, Any]: ) return parameters_state.parameters - def deserialize(self, data: str) -> Any: + def serve_load_data(self, data: str) -> Any: parameters = self.parameters df = read_csv(StringIO(data)) diff --git a/flash/text/classification/data.py b/flash/text/classification/data.py index 6309d951c9..3ea9d9d9e2 100644 --- a/flash/text/classification/data.py +++ b/flash/text/classification/data.py @@ -44,7 +44,7 @@ def __init__(self, backbone: str, max_length: int, use_fast: bool = True, **kwar self.tokenizer = AutoTokenizer.from_pretrained(backbone, use_fast=use_fast, **kwargs) self.max_length = max_length - def deserialize(self, text: str) -> Tensor: + def serve_load_sample(self, text: str) -> Tensor: return self.tokenizer(text, max_length=self.max_length, truncation=True, padding="max_length") @property From 2396db3a7b797c123a7ecf1c5b986d8921b25551 Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 1 Dec 2021 10:15:17 +0100 Subject: [PATCH 03/11] update changelog --- CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 981654fc32..a94bf78264 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -30,6 +30,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Changed the `SpeechRecognition` task to use `AutoModelForCTC` rather than just `Wav2Vec2ForCTC` ([#874](https://github.com/PyTorchLightning/lightning-flash/pull/874)) +- Changed the `Deserializer` subclass `ServeInput` ([#1013](https://github.com/PyTorchLightning/lightning-flash/pull/1013)) + ### Deprecated - Deprecated `flash.core.data.process.Serializer` in favour of `flash.core.data.io.output.Output` ([#927](https://github.com/PyTorchLightning/lightning-flash/pull/927)) From 417ccfd7bb1479df8d5170f3fd643e8324e5d44f Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 1 Dec 2021 10:28:56 +0100 Subject: [PATCH 04/11] update --- flash/core/data/process.py | 12 +----------- 1 file changed, 1 insertion(+), 11 deletions(-) diff --git a/flash/core/data/process.py b/flash/core/data/process.py index afa1bf0c32..3db9efeb18 100644 --- a/flash/core/data/process.py +++ b/flash/core/data/process.py @@ -24,26 +24,16 @@ class ServeInput(Input): - def serve_load_data(self, data: Any) -> List[Any]: - raise NotImplementedError - def serve_load_sample(self, sample: Any) -> List[Any]: raise NotImplementedError - @property - @abstractmethod - def serve_example_input(self) -> str: - raise NotImplementedError - class Deserializer(ServeInput): """Deserializer.""" - def deserialize(self, sample: Any) -> Any: # TODO: Output must be a tensor??? - sample = self.serve_load_data(sample) + def deserialize(self, sample: Any) -> Any: return self.serve_load_sample(sample) - @property @abstractmethod def example_input(self) -> str: raise NotImplementedError From 666279cfbfa6e93846f20144b2a291e5f68dd668 Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 1 Dec 2021 16:18:43 +0100 Subject: [PATCH 05/11] update --- flash/core/data/process.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/flash/core/data/process.py b/flash/core/data/process.py index 3db9efeb18..c37766a8ff 100644 --- a/flash/core/data/process.py +++ b/flash/core/data/process.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. import functools -from abc import abstractmethod from typing import Any, List, Mapping from warnings import warn @@ -34,7 +33,6 @@ class Deserializer(ServeInput): def deserialize(self, sample: Any) -> Any: return self.serve_load_sample(sample) - @abstractmethod def example_input(self) -> str: raise NotImplementedError From 36003ccb1c866d03c46ddd0d4ba2cfd06053fa7f Mon Sep 17 00:00:00 2001 From: tchaton Date: Thu, 2 Dec 2021 11:56:35 +0100 Subject: [PATCH 06/11] update --- CHANGELOG.md | 2 +- flash/core/data/data_pipeline.py | 2 ++ flash/core/data/io/input_base.py | 23 +++++++++++++++++++++- flash/core/data/process.py | 19 ++---------------- flash/image/segmentation/data.py | 2 +- flash/tabular/data.py | 2 +- tests/core/data/io/test_input_base.py | 28 ++++++++++++++++++++++++++- 7 files changed, 56 insertions(+), 22 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c62417993e..be598d55d2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -32,7 +32,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Changed the `SpeechRecognition` task to use `AutoModelForCTC` rather than just `Wav2Vec2ForCTC` ([#874](https://github.com/PyTorchLightning/lightning-flash/pull/874)) -- Changed the `Deserializer` subclass `ServeInput` ([#1013](https://github.com/PyTorchLightning/lightning-flash/pull/1013)) +- Changed the `Deserializer` subclass `DeserializerInput` ([#1013](https://github.com/PyTorchLightning/lightning-flash/pull/1013)) - Added `Output` suffix to `Preds`, `FiftyOneDetectionLabels`, `SegmentationLabels`, `FiftyOneDetectionLabels`, `DetectionLabels`, `Classes`, `FiftyOneLabels`, `Labels`, `Logits`, `Probabilities` ([#1011](https://github.com/PyTorchLightning/lightning-flash/pull/1011)) diff --git a/flash/core/data/data_pipeline.py b/flash/core/data/data_pipeline.py index baf0c60cd3..1ceba54fd7 100644 --- a/flash/core/data/data_pipeline.py +++ b/flash/core/data/data_pipeline.py @@ -197,6 +197,8 @@ def _resolve_function_hierarchy( prefixes += ["test"] elif stage == RunningStage.PREDICTING: prefixes += ["predict"] + elif stage == RunningStage.SERVING: + prefixes += ["serve"] prefixes += [None] diff --git a/flash/core/data/io/input_base.py b/flash/core/data/io/input_base.py index 94ef607f49..bedc59c7be 100644 --- a/flash/core/data/io/input_base.py +++ b/flash/core/data/io/input_base.py @@ -15,8 +15,9 @@ import os import sys from copy import copy, deepcopy -from typing import Any, cast, Dict, Iterable, MutableMapping, Optional, Sequence, Tuple, Union +from typing import Any, cast, Dict, Iterable, List, MutableMapping, Optional, Sequence, Tuple, Union +from pytorch_lightning.utilities.exceptions import MisconfigurationException from torch.utils.data import Dataset import flash @@ -232,3 +233,23 @@ def __iter__(self): def __next__(self) -> Any: return self._call_load_sample(next(self.data_iter)) + + +class DeserializerInput(Input): + def __init__( + self, + data_pipeline_state: Optional["flash.core.data.data_pipeline.DataPipelineState"] = None, + ) -> None: + if hasattr(self, "serve_load_data"): + raise MisconfigurationException("`serve_load_data` shouldn't be implemented.") + + super().__init__(RunningStage.SERVING, data_pipeline_state=data_pipeline_state) + + def serve_load_sample(self, sample: Any) -> List[Any]: + raise NotImplementedError + + def deserialize(self, sample: Any) -> Any: + return self._call_load_sample(sample) + + def example_input(self) -> str: + raise NotImplementedError diff --git a/flash/core/data/process.py b/flash/core/data/process.py index c37766a8ff..5be1f36f90 100644 --- a/flash/core/data/process.py +++ b/flash/core/data/process.py @@ -12,31 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. import functools -from typing import Any, List, Mapping +from typing import Any, Mapping from warnings import warn from deprecate import deprecated import flash -from flash.core.data.io.input import Input +from flash.core.data.io.input_base import DeserializerInput as Deserializer from flash.core.data.io.output import Output -class ServeInput(Input): - def serve_load_sample(self, sample: Any) -> List[Any]: - raise NotImplementedError - - -class Deserializer(ServeInput): - """Deserializer.""" - - def deserialize(self, sample: Any) -> Any: - return self.serve_load_sample(sample) - - def example_input(self) -> str: - raise NotImplementedError - - class DeserializerMapping(Deserializer): # TODO: This is essentially a duplicate of OutputMapping, should be abstracted away somewhere """Deserializer Mapping.""" diff --git a/flash/image/segmentation/data.py b/flash/image/segmentation/data.py index e7ffe7dcdc..1a08b2bfe9 100644 --- a/flash/image/segmentation/data.py +++ b/flash/image/segmentation/data.py @@ -212,7 +212,7 @@ def predict_load_sample(sample: Mapping[str, Any]) -> Mapping[str, Any]: class SemanticSegmentationDeserializer(ImageDeserializer): def serve_load_sample(self, data: str) -> torch.Tensor: - result = super().deserialize(data) + result = super().serve_load_sample(data) result[DataKeys.INPUT] = FT.to_tensor(result[DataKeys.INPUT]) result[DataKeys.METADATA] = {"size": result[DataKeys.INPUT].shape} return result diff --git a/flash/tabular/data.py b/flash/tabular/data.py index 7fa81927d7..fdfd0de4c0 100644 --- a/flash/tabular/data.py +++ b/flash/tabular/data.py @@ -186,7 +186,7 @@ def parameters(self) -> Dict[str, Any]: ) return parameters_state.parameters - def serve_load_data(self, data: str) -> Any: + def serve_load_sample(self, data: str) -> Any: parameters = self.parameters df = read_csv(StringIO(data)) diff --git a/tests/core/data/io/test_input_base.py b/tests/core/data/io/test_input_base.py index 6382a0385c..beb7403a23 100644 --- a/tests/core/data/io/test_input_base.py +++ b/tests/core/data/io/test_input_base.py @@ -12,8 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. import pytest +from pytorch_lightning.utilities.exceptions import MisconfigurationException -from flash.core.data.io.input_base import Input, IterableInput +from flash.core.data.io.input_base import DeserializerInput, Input, IterableInput from flash.core.utilities.stages import RunningStage @@ -55,3 +56,28 @@ def __init__(self, *args, **kwargs): self.data = iter([1, 2, 3]) ValidIterableInput(RunningStage.TRAINING) + + +def test_serve_input(): + + server_input = DeserializerInput() + assert server_input.serving + with pytest.raises(NotImplementedError): + server_input._call_load_sample("") + + class CustomDeserializerInput(DeserializerInput): + def serve_load_data(self, data): + raise NotImplementedError + + def serve_load_sample(self, data): + return data + 1 + + with pytest.raises(MisconfigurationException, match="serve_load_data"): + serve_input = CustomDeserializerInput() + + class CustomDeserializerInput2(DeserializerInput): + def serve_load_sample(self, data): + return data + 1 + + serve_input = CustomDeserializerInput2() + assert serve_input._call_load_sample(1) == 2 From 2d0b0c840872530fe6a24a04ac30e48ed3beb6a5 Mon Sep 17 00:00:00 2001 From: tchaton Date: Thu, 2 Dec 2021 11:58:09 +0100 Subject: [PATCH 07/11] update --- CHANGELOG.md | 2 +- flash/core/data/io/input_base.py | 2 +- flash/core/data/process.py | 2 +- tests/core/data/io/test_input_base.py | 12 ++++++------ 4 files changed, 9 insertions(+), 9 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index be598d55d2..c62417993e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -32,7 +32,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Changed the `SpeechRecognition` task to use `AutoModelForCTC` rather than just `Wav2Vec2ForCTC` ([#874](https://github.com/PyTorchLightning/lightning-flash/pull/874)) -- Changed the `Deserializer` subclass `DeserializerInput` ([#1013](https://github.com/PyTorchLightning/lightning-flash/pull/1013)) +- Changed the `Deserializer` subclass `ServeInput` ([#1013](https://github.com/PyTorchLightning/lightning-flash/pull/1013)) - Added `Output` suffix to `Preds`, `FiftyOneDetectionLabels`, `SegmentationLabels`, `FiftyOneDetectionLabels`, `DetectionLabels`, `Classes`, `FiftyOneLabels`, `Labels`, `Logits`, `Probabilities` ([#1011](https://github.com/PyTorchLightning/lightning-flash/pull/1011)) diff --git a/flash/core/data/io/input_base.py b/flash/core/data/io/input_base.py index bedc59c7be..e815895a24 100644 --- a/flash/core/data/io/input_base.py +++ b/flash/core/data/io/input_base.py @@ -235,7 +235,7 @@ def __next__(self) -> Any: return self._call_load_sample(next(self.data_iter)) -class DeserializerInput(Input): +class ServeInput(Input): def __init__( self, data_pipeline_state: Optional["flash.core.data.data_pipeline.DataPipelineState"] = None, diff --git a/flash/core/data/process.py b/flash/core/data/process.py index 5be1f36f90..48488b0d8d 100644 --- a/flash/core/data/process.py +++ b/flash/core/data/process.py @@ -18,7 +18,7 @@ from deprecate import deprecated import flash -from flash.core.data.io.input_base import DeserializerInput as Deserializer +from flash.core.data.io.input_base import ServeInput as Deserializer from flash.core.data.io.output import Output diff --git a/tests/core/data/io/test_input_base.py b/tests/core/data/io/test_input_base.py index beb7403a23..9d9a841485 100644 --- a/tests/core/data/io/test_input_base.py +++ b/tests/core/data/io/test_input_base.py @@ -14,7 +14,7 @@ import pytest from pytorch_lightning.utilities.exceptions import MisconfigurationException -from flash.core.data.io.input_base import DeserializerInput, Input, IterableInput +from flash.core.data.io.input_base import Input, IterableInput, ServeInput from flash.core.utilities.stages import RunningStage @@ -60,12 +60,12 @@ def __init__(self, *args, **kwargs): def test_serve_input(): - server_input = DeserializerInput() + server_input = ServeInput() assert server_input.serving with pytest.raises(NotImplementedError): server_input._call_load_sample("") - class CustomDeserializerInput(DeserializerInput): + class CustomServeInput(ServeInput): def serve_load_data(self, data): raise NotImplementedError @@ -73,11 +73,11 @@ def serve_load_sample(self, data): return data + 1 with pytest.raises(MisconfigurationException, match="serve_load_data"): - serve_input = CustomDeserializerInput() + serve_input = CustomServeInput() - class CustomDeserializerInput2(DeserializerInput): + class CustomServeInput2(ServeInput): def serve_load_sample(self, data): return data + 1 - serve_input = CustomDeserializerInput2() + serve_input = CustomServeInput2() assert serve_input._call_load_sample(1) == 2 From 74ce353781f520814fdb52becbbdba49dd891f84 Mon Sep 17 00:00:00 2001 From: tchaton Date: Thu, 2 Dec 2021 13:13:45 +0100 Subject: [PATCH 08/11] update --- flash/audio/speech_recognition/data.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/flash/audio/speech_recognition/data.py b/flash/audio/speech_recognition/data.py index 0c50036e17..df8c7d17a9 100644 --- a/flash/audio/speech_recognition/data.py +++ b/flash/audio/speech_recognition/data.py @@ -44,8 +44,8 @@ class SpeechRecognitionDeserializer(Deserializer): @requires("audio") - def __init__(self, sampling_rate: int = 16000): - super().__init__() + def __init__(self, sampling_rate: int = 16000, **kwargs): + super().__init__(**kwargs) self.sampling_rate = sampling_rate def serve_load_sample(self, sample: Any) -> Dict: From 9bfe9c2ee6b5bc40a3d3ab68977ca5688c71aee8 Mon Sep 17 00:00:00 2001 From: tchaton Date: Thu, 2 Dec 2021 13:44:12 +0100 Subject: [PATCH 09/11] update --- flash/core/data/io/input_base.py | 3 +++ flash/core/model.py | 8 +++++++- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/flash/core/data/io/input_base.py b/flash/core/data/io/input_base.py index e815895a24..c53d90dba5 100644 --- a/flash/core/data/io/input_base.py +++ b/flash/core/data/io/input_base.py @@ -253,3 +253,6 @@ def deserialize(self, sample: Any) -> Any: def example_input(self) -> str: raise NotImplementedError + + def __bool__(self): + return True diff --git a/flash/core/model.py b/flash/core/model.py index 073c9baa1c..d58498d0bb 100644 --- a/flash/core/model.py +++ b/flash/core/model.py @@ -809,7 +809,13 @@ def build_data_pipeline( if deserializer is None or type(deserializer) is Deserializer: deserializer = getattr(input_transform, "deserializer", deserializer) - data_pipeline = DataPipeline(input, input_transform, output_transform, deserializer, output) + data_pipeline = DataPipeline( + input=input, + input_transform=input_transform, + output_transform=output_transform, + deserializer=deserializer, + output=output, + ) self._data_pipeline_state = self._data_pipeline_state or DataPipelineState() self.attach_data_pipeline_state(self._data_pipeline_state) self._data_pipeline_state = data_pipeline.initialize(self._data_pipeline_state) From 480db7ab86388583ce6489423d3f42faa6663dda Mon Sep 17 00:00:00 2001 From: tchaton Date: Thu, 2 Dec 2021 13:46:55 +0100 Subject: [PATCH 10/11] update changelog --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c62417993e..687f8fac08 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -32,7 +32,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Changed the `SpeechRecognition` task to use `AutoModelForCTC` rather than just `Wav2Vec2ForCTC` ([#874](https://github.com/PyTorchLightning/lightning-flash/pull/874)) -- Changed the `Deserializer` subclass `ServeInput` ([#1013](https://github.com/PyTorchLightning/lightning-flash/pull/1013)) +- Changed the `Deserializer` to subclass `ServeInput` ([#1013](https://github.com/PyTorchLightning/lightning-flash/pull/1013)) - Added `Output` suffix to `Preds`, `FiftyOneDetectionLabels`, `SegmentationLabels`, `FiftyOneDetectionLabels`, `DetectionLabels`, `Classes`, `FiftyOneLabels`, `Labels`, `Logits`, `Probabilities` ([#1011](https://github.com/PyTorchLightning/lightning-flash/pull/1011)) From eee8c0e58bcdc276513142617936bade3a60a31a Mon Sep 17 00:00:00 2001 From: tchaton Date: Thu, 2 Dec 2021 14:20:22 +0100 Subject: [PATCH 11/11] update --- flash/core/data/batch.py | 2 +- flash/core/data/io/input_base.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/flash/core/data/batch.py b/flash/core/data/batch.py index 0c7a8d9ce2..232df337ac 100644 --- a/flash/core/data/batch.py +++ b/flash/core/data/batch.py @@ -43,7 +43,7 @@ def __init__( def forward(self, sample: str): - sample = self.deserializer.deserialize(sample) + sample = self.deserializer(sample) with self._current_stage_context: with self._per_sample_transform_context: diff --git a/flash/core/data/io/input_base.py b/flash/core/data/io/input_base.py index c53d90dba5..1905cbb8f9 100644 --- a/flash/core/data/io/input_base.py +++ b/flash/core/data/io/input_base.py @@ -248,7 +248,7 @@ def __init__( def serve_load_sample(self, sample: Any) -> List[Any]: raise NotImplementedError - def deserialize(self, sample: Any) -> Any: + def __call__(self, sample: Any) -> Any: return self._call_load_sample(sample) def example_input(self) -> str: