Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Data Pipeline V2: Refactor Deserializer into Serve Input #1013

Merged
merged 16 commits into from
Dec 2, 2021
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Changed the `SpeechRecognition` task to use `AutoModelForCTC` rather than just `Wav2Vec2ForCTC` ([#874](https://github.com/PyTorchLightning/lightning-flash/pull/874))

- Changed the `Deserializer` to subclass `ServeInput` ([#1013](https://github.com/PyTorchLightning/lightning-flash/pull/1013))

- Added `Output` suffix to `Preds`, `FiftyOneDetectionLabels`, `SegmentationLabels`, `DetectionLabels`, `Classes`, `FiftyOneLabels`, `Labels`, `Logits`, `Probabilities` ([#1011](https://github.com/PyTorchLightning/lightning-flash/pull/1011))

### Deprecated
Expand Down
8 changes: 4 additions & 4 deletions flash/audio/speech_recognition/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,12 @@


class SpeechRecognitionDeserializer(Deserializer):
def __init__(self, sampling_rate: int = 16000):
super().__init__()

@requires("audio")
def __init__(self, sampling_rate: int = 16000, **kwargs):
super().__init__(**kwargs)
self.sampling_rate = sampling_rate

def deserialize(self, sample: Any) -> Dict:
def serve_load_sample(self, sample: Any) -> Dict:
encoded_with_padding = (sample + "===").encode("ascii")
audio = base64.b64decode(encoded_with_padding)
buffer = io.BytesIO(audio)
Expand Down
3 changes: 2 additions & 1 deletion flash/core/data/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ def __init__(
self._per_sample_transform_context = CurrentFuncContext("per_sample_transform", input_transform)

def forward(self, sample: str):
sample = self.deserializer(sample)

sample = self.deserializer.deserialize(sample)

with self._current_stage_context:
with self._per_sample_transform_context:
Expand Down
2 changes: 2 additions & 0 deletions flash/core/data/data_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,8 @@ def _resolve_function_hierarchy(
prefixes += ["test"]
elif stage == RunningStage.PREDICTING:
prefixes += ["predict"]
elif stage == RunningStage.SERVING:
prefixes += ["serve"]

prefixes += [None]

Expand Down
26 changes: 25 additions & 1 deletion flash/core/data/io/input_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@
import os
import sys
from copy import copy, deepcopy
from typing import Any, cast, Dict, Iterable, MutableMapping, Optional, Sequence, Tuple, Union
from typing import Any, cast, Dict, Iterable, List, MutableMapping, Optional, Sequence, Tuple, Union

from pytorch_lightning.utilities.exceptions import MisconfigurationException
from torch.utils.data import Dataset

import flash
Expand Down Expand Up @@ -232,3 +233,26 @@ def __iter__(self):

def __next__(self) -> Any:
return self._call_load_sample(next(self.data_iter))


class ServeInput(Input):
    """An ``Input`` that is always created for the ``RunningStage.SERVING`` stage.

    Subclasses override :meth:`serve_load_sample` to turn a single raw sample into the loaded sample, and
    :meth:`example_input` to provide an example payload. An instance exposing a ``serve_load_data`` attribute is
    rejected when it is constructed.
    """

    def __init__(
        self,
        data_pipeline_state: Optional["flash.core.data.data_pipeline.DataPipelineState"] = None,
    ) -> None:
        """Validate the subclass and initialise the parent ``Input`` in the serving stage.

        Args:
            data_pipeline_state: Optional state object, forwarded unchanged to the parent constructor.

        Raises:
            MisconfigurationException: If the instance has a ``serve_load_data`` attribute.
        """
        # The check runs before the parent constructor so that an invalid subclass is never initialised.
        has_forbidden_hook = hasattr(self, "serve_load_data")
        if has_forbidden_hook:
            raise MisconfigurationException("`serve_load_data` shouldn't be implemented.")

        serving_stage = RunningStage.SERVING
        super().__init__(serving_stage, data_pipeline_state=data_pipeline_state)

    def serve_load_sample(self, sample: Any) -> List[Any]:
        """Load one sample for serving; the base implementation must be overridden.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def deserialize(self, sample: Any) -> Any:
        """Return the result of the inherited ``_call_load_sample`` for ``sample``.

        NOTE(review): presumably ``_call_load_sample`` dispatches to ``serve_load_sample`` -- confirm against
        ``Input``.
        """
        loaded = self._call_load_sample(sample)
        return loaded

    def example_input(self) -> str:
        """Return an example input; the base implementation must be overridden.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def __bool__(self):
        """Report the instance as truthy unconditionally.

        NOTE(review): presumably needed because the parent's truthiness would otherwise depend on loaded data --
        confirm against ``Input``.
        """
        return True
18 changes: 1 addition & 17 deletions flash/core/data/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,30 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
from abc import abstractmethod
from typing import Any, Mapping
from warnings import warn

from deprecate import deprecated

import flash
from flash.core.data.io.input_base import ServeInput as Deserializer
from flash.core.data.io.output import Output
from flash.core.data.properties import Properties


class Deserializer(Properties):
"""Deserializer."""

def deserialize(self, sample: Any) -> Any: # TODO: Output must be a tensor???
raise NotImplementedError

@property
@abstractmethod
def example_input(self) -> str:
raise NotImplementedError

def __call__(self, sample: Any) -> Any:
return self.deserialize(sample)


class DeserializerMapping(Deserializer):
tchaton marked this conversation as resolved.
Show resolved Hide resolved
Expand Down
8 changes: 7 additions & 1 deletion flash/core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -809,7 +809,13 @@ def build_data_pipeline(
if deserializer is None or type(deserializer) is Deserializer:
deserializer = getattr(input_transform, "deserializer", deserializer)

data_pipeline = DataPipeline(input, input_transform, output_transform, deserializer, output)
data_pipeline = DataPipeline(
input=input,
input_transform=input_transform,
output_transform=output_transform,
deserializer=deserializer,
tchaton marked this conversation as resolved.
Show resolved Hide resolved
output=output,
)
self._data_pipeline_state = self._data_pipeline_state or DataPipelineState()
self.attach_data_pipeline_state(self._data_pipeline_state)
self._data_pipeline_state = data_pipeline.initialize(self._data_pipeline_state)
Expand Down
2 changes: 1 addition & 1 deletion flash/image/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def image_loader(filepath: str):

class ImageDeserializer(Deserializer):
tchaton marked this conversation as resolved.
Show resolved Hide resolved
@requires("image")
def deserialize(self, data: str) -> Dict:
def serve_load_sample(self, data: str) -> Dict:
encoded_with_padding = (data + "===").encode("ascii")
img = base64.b64decode(encoded_with_padding)
buffer = BytesIO(img)
Expand Down
4 changes: 2 additions & 2 deletions flash/image/segmentation/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,8 +191,8 @@ def load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]:


class SemanticSegmentationDeserializer(ImageDeserializer):
def deserialize(self, data: str) -> Dict[str, Any]:
result = super().deserialize(data)
def serve_load_sample(self, data: str) -> Dict[str, Any]:
result = super().serve_load_sample(data)
result[DataKeys.INPUT] = FT.to_tensor(result[DataKeys.INPUT])
result[DataKeys.METADATA] = {"size": result[DataKeys.INPUT].shape[-2:]}
return result
Expand Down
2 changes: 1 addition & 1 deletion flash/tabular/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ def parameters(self) -> Dict[str, Any]:
)
return parameters_state.parameters

def deserialize(self, data: str) -> Any:
def serve_load_sample(self, data: str) -> Any:
parameters = self.parameters

df = read_csv(StringIO(data))
Expand Down
2 changes: 1 addition & 1 deletion flash/text/classification/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def __init__(self, backbone: str, max_length: int, use_fast: bool = True, **kwar
self.tokenizer = AutoTokenizer.from_pretrained(backbone, use_fast=use_fast, **kwargs)
self.max_length = max_length

def deserialize(self, text: str) -> Tensor:
def serve_load_sample(self, text: str) -> Tensor:
return self.tokenizer(text, max_length=self.max_length, truncation=True, padding="max_length")

@property
Expand Down
28 changes: 27 additions & 1 deletion tests/core/data/io/test_input_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
from pytorch_lightning.utilities.exceptions import MisconfigurationException

from flash.core.data.io.input_base import Input, IterableInput
from flash.core.data.io.input_base import Input, IterableInput, ServeInput
from flash.core.utilities.stages import RunningStage


Expand Down Expand Up @@ -55,3 +56,28 @@ def __init__(self, *args, **kwargs):
self.data = iter([1, 2, 3])

ValidIterableInput(RunningStage.TRAINING)


def test_serve_input():
    """Check ``ServeInput`` construction rules and that loading goes through ``serve_load_sample``."""

    class WithLoadData(ServeInput):
        # Declares the hook that ``ServeInput`` refuses at construction time.
        def serve_load_data(self, data):
            raise NotImplementedError

        def serve_load_sample(self, data):
            return data + 1

    class WithLoadSampleOnly(ServeInput):
        def serve_load_sample(self, data):
            return data + 1

    # The bare base class is in the serving stage and has no sample loading implemented.
    base_input = ServeInput()
    assert base_input.serving
    with pytest.raises(NotImplementedError):
        base_input._call_load_sample("")

    # A subclass that defines ``serve_load_data`` cannot be instantiated.
    with pytest.raises(MisconfigurationException, match="serve_load_data"):
        WithLoadData()

    # A subclass that only defines ``serve_load_sample`` has it applied to the sample.
    valid_input = WithLoadSampleOnly()
    loaded = valid_input._call_load_sample(1)
    assert loaded == 2