diff --git a/CHANGELOG.md b/CHANGELOG.md
index 89d200eeff..a4086287ef 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -83,6 +83,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Removed the `backbone` argument from `TextClassificationData`, it is now sufficient to only provide a `backbone` argument to the `TextClassifier` ([#1022](https://github.com/PyTorchLightning/lightning-flash/pull/1022))
 
+- Removed support for the `serve_sanity_check` argument in `flash.Trainer` ([#1062](https://github.com/PyTorchLightning/lightning-flash/pull/1062))
+
 ## [0.5.2] - 2021-11-05
 
 ### Added
diff --git a/flash/audio/speech_recognition/model.py b/flash/audio/speech_recognition/model.py
index 77f1a97d41..5b489a3aa2 100644
--- a/flash/audio/speech_recognition/model.py
+++ b/flash/audio/speech_recognition/model.py
@@ -13,19 +13,32 @@
 # limitations under the License.
 import os
 import warnings
-from typing import Any, Dict
+from typing import Any, Dict, Optional, Type
 
 import torch
 import torch.nn as nn
 
 from flash.audio.speech_recognition.backbone import SPEECH_RECOGNITION_BACKBONES
 from flash.audio.speech_recognition.collate import DataCollatorCTCWithPadding
-from flash.audio.speech_recognition.output_transform import SpeechRecognitionBackboneState
+from flash.audio.speech_recognition.input import SpeechRecognitionDeserializer
+from flash.audio.speech_recognition.output_transform import (
+    SpeechRecognitionBackboneState,
+    SpeechRecognitionOutputTransform,
+)
+from flash.core.data.input_transform import InputTransform
+from flash.core.data.io.input import ServeInput
 from flash.core.data.states import CollateFn
 from flash.core.model import Task
 from flash.core.registry import FlashRegistry
-from flash.core.utilities.imports import _AUDIO_AVAILABLE
-from flash.core.utilities.types import LR_SCHEDULER_TYPE, OPTIMIZER_TYPE, OUTPUT_TYPE
+from flash.core.serve import Composition
+from flash.core.utilities.imports import _AUDIO_AVAILABLE, requires
+from flash.core.utilities.types import (
+    INPUT_TRANSFORM_TYPE,
+    LR_SCHEDULER_TYPE,
+    OPTIMIZER_TYPE,
+    OUTPUT_TRANSFORM_TYPE,
+    OUTPUT_TYPE,
+)
 
 if _AUDIO_AVAILABLE:
     from transformers import Wav2Vec2Processor
@@ -55,6 +68,7 @@ def __init__(
         lr_scheduler: LR_SCHEDULER_TYPE = None,
         learning_rate: float = 1e-5,
         output: OUTPUT_TYPE = None,
+        output_transform: OUTPUT_TRANSFORM_TYPE = SpeechRecognitionOutputTransform(),
     ):
         os.environ["TOKENIZERS_PARALLELISM"] = "TRUE"
         # disable HF thousand warnings
@@ -69,6 +83,7 @@ def __init__(
             lr_scheduler=lr_scheduler,
             learning_rate=learning_rate,
             output=output,
+            output_transform=output_transform,
         )
 
         self.save_hyperparameters()
@@ -86,3 +101,15 @@ def step(self, batch: Any, batch_idx: int, metrics: nn.ModuleDict) -> Any:
         out = self.model(batch["input_values"], labels=batch["labels"])
         out["logs"] = {"loss": out.loss}
         return out
+
+    @requires("serve")
+    def serve(
+        self,
+        host: str = "127.0.0.1",
+        port: int = 8000,
+        sanity_check: bool = True,
+        input_cls: Optional[Type[ServeInput]] = SpeechRecognitionDeserializer,
+        transform: INPUT_TRANSFORM_TYPE = InputTransform,
+        transform_kwargs: Optional[Dict] = None,
+    ) -> Composition:
+        return super().serve(host, port, sanity_check, input_cls, transform, transform_kwargs)
diff --git a/flash/audio/speech_recognition/output_transform.py b/flash/audio/speech_recognition/output_transform.py
index 228bacb17a..56f282bed5 100644
--- a/flash/audio/speech_recognition/output_transform.py
+++ b/flash/audio/speech_recognition/output_transform.py
@@ -34,7 +34,6 @@ class SpeechRecognitionBackboneState(ProcessState):
 
 
 class SpeechRecognitionOutputTransform(OutputTransform):
-    @requires("audio")
     def __init__(self):
         super().__init__()
@@ -54,6 +53,7 @@ def tokenizer(self):
             self._backbone = self.backbone
         return self._tokenizer
 
+    @requires("audio")
     def per_batch_transform(self, batch: Any) -> Any:
         # converts logits into greedy transcription
         pred_ids = torch.argmax(batch.logits, dim=-1)
@@ -62,9 +62,10 @@ def per_batch_transform(self, batch: Any) -> Any:
     def __getstate__(self):
         # TODO: Find out why this is being pickled
         state = self.__dict__.copy()
-        state.pop("_tokenizer")
+        state.pop("_tokenizer", None)
         return state
 
     def __setstate__(self, state):
         self.__dict__.update(state)
-        self._tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(self.backbone)
+        if self.backbone is not None:
+            self._tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(self.backbone)
diff --git a/flash/core/data/batch.py b/flash/core/data/batch.py
index fb3d196d35..e67127a1f1 100644
--- a/flash/core/data/batch.py
+++ b/flash/core/data/batch.py
@@ -11,38 +11,27 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Callable, List, Optional, Sequence, TYPE_CHECKING
+from typing import Any, Sequence, TYPE_CHECKING
 
 import torch
 from torch import Tensor
 
-from flash.core.data.callback import ControlFlow, FlashCallback
-from flash.core.data.utils import convert_to_modules
-from flash.core.utilities.stages import RunningStage
-
 if TYPE_CHECKING:
-    from flash.core.data.input_transform import InputTransform
-    from flash.core.data.process import Deserializer
+    from flash.core.data.io.input import ServeInput
 
 
-class _DeserializeProcessorV2(torch.nn.Module):
+class _ServeInputProcessor(torch.nn.Module):
     def __init__(
         self,
-        deserializer: "Deserializer",
-        input_transform: "InputTransform",
-        per_sample_transform: Callable,
-        callbacks: Optional[List[FlashCallback]] = None,
+        serve_input: "ServeInput",
     ):
         super().__init__()
-        self.input_transform = input_transform
-        self.callback = ControlFlow(callbacks or [])
-        self.deserializer = convert_to_modules(deserializer)
-        self.per_sample_transform = convert_to_modules(per_sample_transform)
+        self.serve_input = serve_input
+        self.dataloader_collate_fn = self.serve_input._create_dataloader_collate_fn([])
 
     def forward(self, sample: str):
-        sample = self.deserializer(sample)
-        sample = self.per_sample_transform(sample)
-        self.callback.on_per_sample_transform(sample, RunningStage.SERVING)
+        sample = self.serve_input._call_load_sample(sample)
+        sample = self.dataloader_collate_fn(sample)
         return sample
diff --git a/flash/core/data/data_pipeline.py b/flash/core/data/data_pipeline.py
index 08bc5b509e..400e465f29 100644
--- a/flash/core/data/data_pipeline.py
+++ b/flash/core/data/data_pipeline.py
@@ -12,16 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import inspect
-from typing import Any, Dict, List, Optional, Sequence, Set, Type, Union
+from typing import Any, Dict, List, Optional, Set, Type, Union
 
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
-from flash.core.data.batch import _DeserializeProcessorV2
-from flash.core.data.input_transform import _create_collate_input_transform_processors
 from flash.core.data.input_transform import InputTransform
 from flash.core.data.input_transform import InputTransform as NewInputTransform
 from flash.core.data.io.input import Input, InputBase
-from flash.core.data.io.input_transform import _InputTransformProcessorV2
 from flash.core.data.io.output import _OutputProcessor, Output
 from flash.core.data.io.output_transform import _OutputTransformProcessor, OutputTransform
 from flash.core.data.process import Deserializer
@@ -133,21 +130,6 @@ def _is_overridden_recursive(
             return has_different_code
         return has_different_code or cls._is_overridden_recursive(method_name, process_obj, super_obj)
 
-    @staticmethod
-    def _identity(samples: Sequence[Any]) -> Sequence[Any]:
-        return samples
-
-    def deserialize_processor(self) -> _DeserializeProcessorV2:
-        return _DeserializeProcessorV2(
-            self._deserializer,
-            self._input_transform_pipeline,
-            self._input_transform_pipeline._per_sample_transform,
-            [],
-        )
-
-    def device_input_transform_processor(self, running_stage: RunningStage) -> _InputTransformProcessorV2:
-        return _create_collate_input_transform_processors(self._input_transform_pipeline, [])[1]
-
     def output_transform_processor(self, running_stage: RunningStage, is_serving=False) -> _OutputTransformProcessor:
         return self._create_output_transform_processor(running_stage, is_serving=is_serving)
diff --git a/flash/core/model.py b/flash/core/model.py
index ee6820478e..ee0470b398 100644
--- a/flash/core/model.py
+++ b/flash/core/model.py
@@ -39,7 +39,7 @@
 import flash
 from flash.core.data.data_pipeline import DataPipeline, DataPipelineState
 from flash.core.data.input_transform import InputTransform
-from flash.core.data.io.input import InputBase
+from flash.core.data.io.input import InputBase, ServeInput
 from flash.core.data.io.output import Output
 from flash.core.data.io.output_transform import OutputTransform
 from flash.core.data.process import Deserializer, DeserializerMapping
@@ -354,18 +354,27 @@ def __init__(
         self.output = output
         self._wrapped_predict_step = False
 
-    def _wrap_predict_step(task, predict_step: Callable) -> Callable:
+    def _wrap_predict_step(self) -> None:
+        if not self._wrapped_predict_step:
+            process_fn = self.build_data_pipeline().output_transform_processor(RunningStage.PREDICTING)
-        process_fn = task.build_data_pipeline().output_transform_processor(RunningStage.PREDICTING)
+            predict_step = self.predict_step
 
-        @functools.wraps(predict_step)
-        def wrapper(self, *args, **kwargs):
-            predictions = predict_step(self, *args, **kwargs)
-            return process_fn(predictions)
+            @functools.wraps(predict_step)
+            def wrapper(*args, **kwargs):
+                predictions = predict_step(*args, **kwargs)
+                return process_fn(predictions)
 
-        task._wrapped_predict_step = True
+            self._original_predict_step = self.predict_step
+            self.predict_step = wrapper
 
-        return wrapper
+            self._wrapped_predict_step = True
+
+    def _unwrap_predict_step(self) -> None:
+        if self._wrapped_predict_step:
+            self.predict_step = self._original_predict_step
+            del self._original_predict_step
+            self._wrapped_predict_step = False
 
     def step(self, batch: Any, batch_idx: int, metrics: nn.ModuleDict) -> Any:
         """Implement the core logic for the training/validation/test step. By default this includes:
@@ -759,11 +768,6 @@ def build_data_pipeline(
             self._data_pipeline_state = data_pipeline.initialize(self._data_pipeline_state)
         return data_pipeline
 
-    @torch.jit.unused
-    @property
-    def is_servable(self) -> bool:
-        return type(self.build_data_pipeline()._deserializer) != Deserializer
-
     @torch.jit.unused
     @property
     def data_pipeline(self) -> DataPipeline:
@@ -803,8 +807,10 @@ def output_transform(self) -> OutputTransform:
         return getattr(self.data_pipeline, "_output_transform", None)
 
     def on_predict_start(self) -> None:
-        if self.trainer and not self._wrapped_predict_step:
-            self.predict_step = self._wrap_predict_step(self.predict_step)
+        self._wrap_predict_step()
+
+    def on_predict_end(self) -> None:
+        self._unwrap_predict_step()
 
     def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
         # This may be an issue since here we create the same problems with pickle as in
@@ -1066,36 +1072,54 @@ def configure_callbacks(self):
         return [BenchmarkConvergenceCI()]
 
     @requires("serve")
-    def run_serve_sanity_check(self):
-        if not self.is_servable:
-            raise NotImplementedError("This Task is not servable. Attach a Deserializer to enable serving.")
-
+    def run_serve_sanity_check(self, serve_input: ServeInput):
         from fastapi.testclient import TestClient
 
         from flash.core.serve.flash_components import build_flash_serve_model_component
 
         print("Running serve sanity check")
-        comp = build_flash_serve_model_component(self)
+        comp = build_flash_serve_model_component(self, serve_input)
         composition = Composition(predict=comp, TESTING=True, DEBUG=True)
         app = composition.serve(host="0.0.0.0", port=8000)
 
         with TestClient(app) as tc:
-            input_str = self.data_pipeline._deserializer.example_input
+            input_str = serve_input.example_input
             body = {"session": "UUID", "payload": {"inputs": {"data": input_str}}}
             resp = tc.post("http://0.0.0.0:8000/predict", json=body)
             print(f"Sanity check response: {resp.json()}")
 
     @requires("serve")
-    def serve(self, host: str = "127.0.0.1", port: int = 8000, sanity_check: bool = True) -> "Composition":
-        if not self.is_servable:
-            raise NotImplementedError("This Task is not servable. Attach a Deserializer to enable serving.")
+    def serve(
+        self,
+        host: str = "127.0.0.1",
+        port: int = 8000,
+        sanity_check: bool = True,
+        input_cls: Optional[Type[ServeInput]] = None,
+        transform: INPUT_TRANSFORM_TYPE = InputTransform,
+        transform_kwargs: Optional[Dict] = None,
+    ) -> "Composition":
+        """Serve the ``Task``. Override this method to provide a default ``input_cls``, ``transform``, and
+        ``transform_kwargs``.
+
+        Args:
+            host: The IP address to host the ``Task`` on.
+            port: The port to host on.
+            sanity_check: If ``True``, runs a sanity check before serving.
+            input_cls: The ``ServeInput`` type to use.
+            transform: The transform to use when serving.
+            transform_kwargs: Keyword arguments used to instantiate the transform.
+        """
         from flash.core.serve.flash_components import build_flash_serve_model_component
 
+        if input_cls is None:
+            raise NotImplementedError("The `input_cls` must be provided to enable serving.")
+
+        serve_input = input_cls(transform=transform, transform_kwargs=transform_kwargs)
+
         if sanity_check:
-            self.run_serve_sanity_check()
+            self.run_serve_sanity_check(serve_input)
 
-        comp = build_flash_serve_model_component(self)
+        comp = build_flash_serve_model_component(self, serve_input)
         composition = Composition(predict=comp, TESTING=flash._IS_TESTING)
         composition.serve(host=host, port=port)
         return composition
diff --git a/flash/core/serve/flash_components.py b/flash/core/serve/flash_components.py
index ae111432d2..6a10919e5c 100644
--- a/flash/core/serve/flash_components.py
+++ b/flash/core/serve/flash_components.py
@@ -3,6 +3,8 @@
 
 import torch
 
+from flash.core.data.batch import _ServeInputProcessor
+from flash.core.data.data_pipeline import DataPipelineState
 from flash.core.data.io.input import DataKeys
 from flash.core.serve import expose, ModelComponent
 from flash.core.serve.types.base import BaseType
@@ -46,7 +48,18 @@ def deserialize(self, data: str) -> Any:  # pragma: no cover
         return None
 
 
-def build_flash_serve_model_component(model):
+def build_flash_serve_model_component(model, serve_input):
+
+    data_pipeline_state = DataPipelineState()
+    for properties in [
+        serve_input,
+        getattr(serve_input, "transform", None),
+        model._output_transform,
+        model._output,
+        model,
+    ]:
+        if properties is not None and hasattr(properties, "attach_data_pipeline_state"):
+            properties.attach_data_pipeline_state(data_pipeline_state)
 
     data_pipeline = model.build_data_pipeline()
 
@@ -55,23 +68,22 @@ def __init__(self, model):
             self.model = model
             self.model.eval()
             self.data_pipeline = model.build_data_pipeline()
-            self.deserializer = self.data_pipeline._deserializer
-            self.dataloader_collate_fn = self.data_pipeline._deserializer._create_dataloader_collate_fn([])
-            self.on_after_batch_transfer_fn = self.data_pipeline._deserializer._create_on_after_batch_transfer_fn([])
+            self.serve_input = serve_input
+            self.dataloader_collate_fn = self.serve_input._create_dataloader_collate_fn([])
+            self.on_after_batch_transfer_fn = self.serve_input._create_on_after_batch_transfer_fn([])
             self.output_transform_processor = self.data_pipeline.output_transform_processor(
-                RunningStage.PREDICTING, is_serving=True
+                RunningStage.SERVING, is_serving=True
             )
             # todo (tchaton) Remove this hack
             self.extra_arguments = len(inspect.signature(self.model.transfer_batch_to_device).parameters) == 3
             self.device = self.model.device
 
         @expose(
-            inputs={"inputs": FlashInputs(data_pipeline._deserializer._call_load_sample)},
+            inputs={"inputs": FlashInputs(_ServeInputProcessor(serve_input))},
             outputs={"outputs": FlashOutputs(data_pipeline.output_processor())},
         )
         def predict(self, inputs):
             with torch.no_grad():
-                inputs = self.dataloader_collate_fn(inputs)
                 if self.extra_arguments:
                     inputs = self.model.transfer_batch_to_device(inputs, self.device, 0)
                 else:
diff --git a/flash/core/trainer.py b/flash/core/trainer.py
index f969fc8e34..affbb033b9 100644
--- a/flash/core/trainer.py
+++ b/flash/core/trainer.py
@@ -28,7 +28,6 @@
 
 import flash
 from flash.core.model import Task
-from flash.core.utilities.imports import _SERVE_AVAILABLE
 
 
 def from_argparse_args(cls, args: Union[Namespace, ArgumentParser], **kwargs):
@@ -73,7 +72,7 @@ def insert_env_defaults(self, *args, **kwargs):
 
 class Trainer(PlTrainer):
     @_defaults_from_env_vars
-    def __init__(self, *args, serve_sanity_check: bool = False, **kwargs):
+    def __init__(self, *args, **kwargs):
         if flash._IS_TESTING:
             if torch.cuda.is_available():
                 kwargs["gpus"] = 1
@@ -89,21 +88,6 @@ def __init__(self, *args, serve_sanity_check: bool = False, **kwargs):
                 kwargs["precision"] = 32
         super().__init__(*args, **kwargs)
 
-        self.serve_sanity_check = serve_sanity_check
-
-    def _run_sanity_check(self, ref_model):
-        if hasattr(super(), "_run_sanity_check"):
-            super()._run_sanity_check(ref_model)
-
-        self.run_sanity_check(ref_model)
-
-    def run_sanity_check(self, ref_model):
-        if hasattr(super(), "run_sanity_check"):
-            super().run_sanity_check(ref_model)
-
-        if self.serve_sanity_check and ref_model.is_servable and _SERVE_AVAILABLE:
-            ref_model.run_serve_sanity_check()
-
     # TODO @(tchaton) remove `reset_train_val_dataloaders` from run_train function
     def _run_train(self) -> None:
         self._pre_training_routine()
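
Migration note (a sketch, not part of the patch): code that previously passed `serve_sanity_check` to `flash.Trainer` should drop the keyword and rely on the `sanity_check` flag of `serve`, which performs the same probe via `run_serve_sanity_check`:

```python
from flash import Trainer
from flash.image import ImageClassifier

model = ImageClassifier(2)

# Before this PR: Trainer(serve_sanity_check=True) probed serving during sanity checking.
trainer = Trainer(max_epochs=1)  # the keyword no longer exists

# After this PR: the probe runs when serving starts (sanity_check defaults to True).
model.serve(sanity_check=True)
```
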
diff --git a/flash/image/classification/model.py b/flash/image/classification/model.py
index 47479f5748..870491b5db 100644
--- a/flash/image/classification/model.py
+++ b/flash/image/classification/model.py
@@ -12,16 +12,28 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from types import FunctionType
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Type, Union
 
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from torch import nn
 
 from flash.core.classification import ClassificationAdapterTask, LabelsOutput
+from flash.core.data.io.input import ServeInput
 from flash.core.registry import FlashRegistry
-from flash.core.utilities.types import LOSS_FN_TYPE, LR_SCHEDULER_TYPE, METRICS_TYPE, OPTIMIZER_TYPE, OUTPUT_TYPE
+from flash.core.serve import Composition
+from flash.core.utilities.imports import requires
+from flash.core.utilities.types import (
+    INPUT_TRANSFORM_TYPE,
+    LOSS_FN_TYPE,
+    LR_SCHEDULER_TYPE,
+    METRICS_TYPE,
+    OPTIMIZER_TYPE,
+    OUTPUT_TYPE,
+)
 from flash.image.classification.adapters import TRAINING_STRATEGIES
 from flash.image.classification.backbones import IMAGE_CLASSIFIER_BACKBONES
+from flash.image.classification.transforms import ImageClassificationInputTransform
+from flash.image.data import ImageDeserializer
 
 
 class ImageClassifier(ClassificationAdapterTask):
@@ -148,6 +160,18 @@ def available_pretrained_weights(cls, backbone: str):
 
         return pretrained_weights
 
+    @requires("serve")
+    def serve(
+        self,
+        host: str = "127.0.0.1",
+        port: int = 8000,
+        sanity_check: bool = True,
+        input_cls: Optional[Type[ServeInput]] = ImageDeserializer,
+        transform: INPUT_TRANSFORM_TYPE = ImageClassificationInputTransform,
+        transform_kwargs: Optional[Dict] = None,
+    ) -> Composition:
+        return super().serve(host, port, sanity_check, input_cls, transform, transform_kwargs)
+
     def _ci_benchmark_fn(self, history: List[Dict[str, Any]]):
         """This function is used only for debugging usage with CI."""
         if self.hparams.multi_label:
diff --git a/flash/image/segmentation/model.py b/flash/image/segmentation/model.py
index 37eb2c19c8..c27b589958 100644
--- a/flash/image/segmentation/model.py
+++ b/flash/image/segmentation/model.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Type, Union
 
 import torch
 from torch import nn
@@ -19,12 +19,14 @@
 from torchmetrics import IoU
 
 from flash.core.classification import ClassificationTask
-from flash.core.data.io.input import DataKeys
+from flash.core.data.io.input import DataKeys, ServeInput
 from flash.core.data.io.output_transform import OutputTransform
 from flash.core.registry import FlashRegistry
-from flash.core.utilities.imports import _KORNIA_AVAILABLE
+from flash.core.serve import Composition
+from flash.core.utilities.imports import _KORNIA_AVAILABLE, requires
 from flash.core.utilities.isinstance import _isinstance
 from flash.core.utilities.types import (
+    INPUT_TRANSFORM_TYPE,
     LOSS_FN_TYPE,
     LR_SCHEDULER_TYPE,
     METRICS_TYPE,
@@ -34,7 +36,9 @@
 )
 from flash.image.segmentation.backbones import SEMANTIC_SEGMENTATION_BACKBONES
 from flash.image.segmentation.heads import SEMANTIC_SEGMENTATION_HEADS
+from flash.image.segmentation.input import SemanticSegmentationDeserializer
 from flash.image.segmentation.output import SegmentationLabelsOutput
+from flash.image.segmentation.transforms import SemanticSegmentationInputTransform
 
 if _KORNIA_AVAILABLE:
     import kornia as K
@@ -174,6 +178,18 @@ def available_pretrained_weights(cls, backbone: str):
 
         return pretrained_weights
 
+    @requires("serve")
+    def serve(
+        self,
+        host: str = "127.0.0.1",
+        port: int = 8000,
+        sanity_check: bool = True,
+        input_cls: Optional[Type[ServeInput]] = SemanticSegmentationDeserializer,
+        transform: INPUT_TRANSFORM_TYPE = SemanticSegmentationInputTransform,
+        transform_kwargs: Optional[Dict] = None,
+    ) -> Composition:
+        return super().serve(host, port, sanity_check, input_cls, transform, transform_kwargs)
+
     @staticmethod
     def _ci_benchmark_fn(history: List[Dict[str, Any]]):
         """This function is used only for debugging usage with CI."""
diff --git a/flash/tabular/classification/model.py b/flash/tabular/classification/model.py
index a4f18201e9..885a5677eb 100644
--- a/flash/tabular/classification/model.py
+++ b/flash/tabular/classification/model.py
@@ -11,15 +11,25 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Callable, Dict, List, Tuple
+from functools import partial
+from typing import Any, Callable, Dict, List, Optional, Tuple, Type
 
 import torch
 from torch.nn import functional as F
 
 from flash.core.classification import ClassificationTask, ProbabilitiesOutput
-from flash.core.data.io.input import DataKeys
-from flash.core.utilities.imports import _TABULAR_AVAILABLE
-from flash.core.utilities.types import LR_SCHEDULER_TYPE, METRICS_TYPE, OPTIMIZER_TYPE, OUTPUT_TYPE
+from flash.core.data.input_transform import InputTransform
+from flash.core.data.io.input import DataKeys, ServeInput
+from flash.core.serve import Composition
+from flash.core.utilities.imports import _TABULAR_AVAILABLE, requires
+from flash.core.utilities.types import (
+    INPUT_TRANSFORM_TYPE,
+    LR_SCHEDULER_TYPE,
+    METRICS_TYPE,
+    OPTIMIZER_TYPE,
+    OUTPUT_TYPE,
+)
+from flash.tabular.input import TabularDeserializer
 
 if _TABULAR_AVAILABLE:
     from pytorch_tabnet.tab_network import TabNet
@@ -119,3 +129,18 @@ def from_data(cls, datamodule, **kwargs) -> "TabularClassifier":
     def _ci_benchmark_fn(history: List[Dict[str, Any]]):
         """This function is used only for debugging usage with CI."""
         assert history[-1]["val_accuracy"] > 0.6, history[-1]["val_accuracy"]
+
+    @requires("serve")
+    def serve(
+        self,
+        host: str = "127.0.0.1",
+        port: int = 8000,
+        sanity_check: bool = True,
+        input_cls: Optional[Type[ServeInput]] = TabularDeserializer,
+        transform: INPUT_TRANSFORM_TYPE = InputTransform,
+        transform_kwargs: Optional[Dict] = None,
+        parameters: Optional[Dict[str, Any]] = None,
+    ) -> Composition:
+        return super().serve(
+            host, port, sanity_check, partial(input_cls, parameters=parameters), transform, transform_kwargs
+        )
diff --git a/flash/tabular/input.py b/flash/tabular/input.py
index 6790978e45..ee1bd7fd0b 100644
--- a/flash/tabular/input.py
+++ b/flash/tabular/input.py
@@ -171,15 +171,21 @@ def load_data(
 
 
 class TabularDeserializer(Deserializer):
+    def __init__(self, *args, parameters: Optional[Dict[str, Any]] = None, **kwargs):
+        self._parameters = parameters
+        super().__init__(*args, **kwargs)
+
     @property
     def parameters(self) -> Dict[str, Any]:
+        if self._parameters is not None:
+            return self._parameters
         parameters_state = self.get_state(TabularParametersState)
-        if parameters_state is None or parameters_state.parameters is None:
-            raise MisconfigurationException(
-                "Tabular tasks must previously have been trained in order to support serving as parameters from the "
-                "train data are required."
-            )
-        return parameters_state.parameters
+        if parameters_state is not None and parameters_state.parameters is not None:
+            return parameters_state.parameters
+        raise MisconfigurationException(
+            "Tabular tasks must previously have been trained in order to support serving or the `parameters` argument "
+            "must be provided to the `serve` method."
+        )
 
     def serve_load_sample(self, data: str) -> Any:
         parameters = self.parameters
diff --git a/flash/tabular/regression/model.py b/flash/tabular/regression/model.py
index 7e2f9d401b..1690a94b25 100644
--- a/flash/tabular/regression/model.py
+++ b/flash/tabular/regression/model.py
@@ -11,15 +11,25 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Callable, List, Tuple
+from functools import partial
+from typing import Any, Callable, Dict, List, Optional, Tuple, Type
 
 import torch
 from torch.nn import functional as F
 
-from flash.core.data.io.input import DataKeys
+from flash.core.data.input_transform import InputTransform
+from flash.core.data.io.input import DataKeys, ServeInput
 from flash.core.regression import RegressionTask
-from flash.core.utilities.imports import _TABULAR_AVAILABLE
-from flash.core.utilities.types import LR_SCHEDULER_TYPE, METRICS_TYPE, OPTIMIZER_TYPE, OUTPUT_TYPE
+from flash.core.serve import Composition
+from flash.core.utilities.imports import _TABULAR_AVAILABLE, requires
+from flash.core.utilities.types import (
+    INPUT_TRANSFORM_TYPE,
+    LR_SCHEDULER_TYPE,
+    METRICS_TYPE,
+    OPTIMIZER_TYPE,
+    OUTPUT_TYPE,
+)
+from flash.tabular.input import TabularDeserializer
 
 if _TABULAR_AVAILABLE:
     from pytorch_tabnet.tab_network import TabNet
@@ -109,3 +119,18 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> A
     def from_data(cls, datamodule, **kwargs) -> "TabularRegressor":
         model = cls(datamodule.num_features, datamodule.embedding_sizes, **kwargs)
         return model
+
+    @requires("serve")
+    def serve(
+        self,
+        host: str = "127.0.0.1",
+        port: int = 8000,
+        sanity_check: bool = True,
+        input_cls: Optional[Type[ServeInput]] = TabularDeserializer,
+        transform: INPUT_TRANSFORM_TYPE = InputTransform,
+        transform_kwargs: Optional[Dict] = None,
+        parameters: Optional[Dict[str, Any]] = None,
+    ) -> Composition:
+        return super().serve(
+            host, port, sanity_check, partial(input_cls, parameters=parameters), transform, transform_kwargs
+        )
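
Because `TabularDeserializer` needs the preprocessing parameters computed from the training data, both tabular tasks accept an extra `parameters` argument that is bound into `input_cls` via `functools.partial`. A sketch of the resulting workflow (the CSV fields and file are hypothetical; any `TabularClassificationData` setup works):

```python
from flash.tabular import TabularClassificationData, TabularClassifier

datamodule = TabularClassificationData.from_csv(
    categorical_fields=["Sex"],  # hypothetical column names
    numerical_fields=["Age"],
    target_fields="Survived",
    train_file="titanic.csv",  # hypothetical file
    batch_size=1,
)
model = TabularClassifier.from_data(datamodule)
model.eval()

# datamodule.parameters supplies the state that previously had to be attached by hand.
model.serve(parameters=datamodule.parameters)
```
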
diff --git a/flash/text/classification/model.py b/flash/text/classification/model.py
index dafe35b125..7dc4d6e220 100644
--- a/flash/text/classification/model.py
+++ b/flash/text/classification/model.py
@@ -13,18 +13,28 @@
 # limitations under the License.
 import os
 import warnings
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Optional, Type
 
 import torch
 from pytorch_lightning import Callback
 
 from flash.core.classification import ClassificationTask, LabelsOutput
-from flash.core.data.io.input import DataKeys
+from flash.core.data.io.input import DataKeys, ServeInput
+from flash.core.integrations.transformers.input_transform import TransformersInputTransform
 from flash.core.integrations.transformers.states import TransformersBackboneState
 from flash.core.registry import FlashRegistry
-from flash.core.utilities.imports import _TRANSFORMERS_AVAILABLE
-from flash.core.utilities.types import LOSS_FN_TYPE, LR_SCHEDULER_TYPE, METRICS_TYPE, OPTIMIZER_TYPE, OUTPUT_TYPE
+from flash.core.serve import Composition
+from flash.core.utilities.imports import _TRANSFORMERS_AVAILABLE, requires
+from flash.core.utilities.types import (
+    INPUT_TRANSFORM_TYPE,
+    LOSS_FN_TYPE,
+    LR_SCHEDULER_TYPE,
+    METRICS_TYPE,
+    OPTIMIZER_TYPE,
+    OUTPUT_TYPE,
+)
 from flash.text.classification.backbones import TEXT_CLASSIFIER_BACKBONES
+from flash.text.input import TextDeserializer
 from flash.text.ort_callback import ORTCallback
 
 if _TRANSFORMERS_AVAILABLE:
@@ -122,3 +132,15 @@ def configure_callbacks(self) -> List[Callback]:
         if self.enable_ort:
             callbacks.append(ORTCallback())
         return callbacks
+
+    @requires("serve")
+    def serve(
+        self,
+        host: str = "127.0.0.1",
+        port: int = 8000,
+        sanity_check: bool = True,
+        input_cls: Optional[Type[ServeInput]] = TextDeserializer,
+        transform: INPUT_TRANSFORM_TYPE = TransformersInputTransform,
+        transform_kwargs: Optional[Dict] = None,
+    ) -> Composition:
+        return super().serve(host, port, sanity_check, input_cls, transform, transform_kwargs)
diff --git a/flash/text/seq2seq/core/model.py b/flash/text/seq2seq/core/model.py
index f4c4e96425..0a8ba55c80 100644
--- a/flash/text/seq2seq/core/model.py
+++ b/flash/text/seq2seq/core/model.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 import os
 import warnings
-from typing import Any, Dict, Iterable, List, Optional, Union
+from typing import Any, Dict, Iterable, List, Optional, Type, Union
 
 import torch
 from pytorch_lightning import Callback
@@ -21,14 +21,23 @@
 from torch import Tensor
 from torch.nn import Module
 
-from flash.core.data.io.input import DataKeys
+from flash.core.data.io.input import DataKeys, ServeInput
 from flash.core.data.io.output_transform import OutputTransform
+from flash.core.integrations.transformers.input_transform import TransformersInputTransform
 from flash.core.integrations.transformers.states import TransformersBackboneState
 from flash.core.model import Task
 from flash.core.registry import ExternalRegistry, FlashRegistry
-from flash.core.utilities.imports import _TEXT_AVAILABLE
+from flash.core.serve import Composition
+from flash.core.utilities.imports import _TEXT_AVAILABLE, requires
 from flash.core.utilities.providers import _HUGGINGFACE
-from flash.core.utilities.types import LOSS_FN_TYPE, LR_SCHEDULER_TYPE, METRICS_TYPE, OPTIMIZER_TYPE
+from flash.core.utilities.types import (
+    INPUT_TRANSFORM_TYPE,
+    LOSS_FN_TYPE,
+    LR_SCHEDULER_TYPE,
+    METRICS_TYPE,
+    OPTIMIZER_TYPE,
+)
+from flash.text.input import TextDeserializer
 from flash.text.ort_callback import ORTCallback
 from flash.text.seq2seq.core.output_transform import Seq2SeqOutputTransform
@@ -182,3 +191,15 @@ def configure_callbacks(self) -> List[Callback]:
         if self.enable_ort:
             callbacks.append(ORTCallback())
         return callbacks
+
+    @requires("serve")
+    def serve(
+        self,
+        host: str = "127.0.0.1",
+        port: int = 8000,
+        sanity_check: bool = True,
+        input_cls: Optional[Type[ServeInput]] = TextDeserializer,
+        transform: INPUT_TRANSFORM_TYPE = TransformersInputTransform,
+        transform_kwargs: Optional[Dict] = None,
+    ) -> Composition:
+        return super().serve(host, port, sanity_check, input_cls, transform, transform_kwargs)
diff --git a/flash_examples/serve/image_classification/inference_server.py b/flash_examples/serve/image_classification/inference_server.py
index 9b173e1c14..00a519a63b 100644
--- a/flash_examples/serve/image_classification/inference_server.py
+++ b/flash_examples/serve/image_classification/inference_server.py
@@ -11,12 +11,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from flash import RunningStage
-from flash.image import ImageClassificationInputTransform, ImageClassifier
-from flash.image.data import ImageDeserializer
+from flash.image import ImageClassifier
 
 model = ImageClassifier.load_from_checkpoint(
     "https://flash-weights.s3.amazonaws.com/0.6.0/image_classification_model.pt"
 )
-model.deserializer = ImageDeserializer(transform=ImageClassificationInputTransform(RunningStage.SERVING))
 model.serve()
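
For reference, a client for the served endpoint only needs the request shape used by `run_serve_sanity_check` above; a sketch with `requests` (assuming the default image `ServeInput` expects base64-encoded bytes, as the Flash serve `Image` type does):

```python
import base64

import requests

with open("path/to/image.png", "rb") as f:  # hypothetical input file
    data = base64.b64encode(f.read()).decode("UTF-8")

# Same payload shape as the sanity check: {"session": ..., "payload": {"inputs": {"data": ...}}}
body = {"session": "UUID", "payload": {"inputs": {"data": data}}}
resp = requests.post("http://127.0.0.1:8000/predict", json=body)
print(resp.json())
```
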
diff --git a/tests/audio/speech_recognition/test_model.py b/tests/audio/speech_recognition/test_model.py
index f817a8c561..84929447ca 100644
--- a/tests/audio/speech_recognition/test_model.py
+++ b/tests/audio/speech_recognition/test_model.py
@@ -22,8 +22,7 @@
 from flash import Trainer
 from flash.__main__ import main
 from flash.audio import SpeechRecognition
-from flash.audio.speech_recognition.data import InputTransform, SpeechRecognitionData, SpeechRecognitionOutputTransform
-from flash.audio.speech_recognition.input import SpeechRecognitionDeserializer
+from flash.audio.speech_recognition.data import InputTransform, SpeechRecognitionData
 from flash.core.data.io.input import DataKeys, Input
 from flash.core.utilities.imports import _AUDIO_AVAILABLE
 from flash.core.utilities.stages import RunningStage
@@ -83,11 +82,6 @@ def test_jit(tmpdir):
 @mock.patch("flash._IS_TESTING", True)
 def test_serve():
     model = SpeechRecognition(backbone=TEST_BACKBONE)
-
-    model._deserializer = SpeechRecognitionDeserializer(transform=InputTransform(RunningStage.SERVING))
-    # TODO: Serve should share the state
-    model._deserializer.transform._state = model._state
-    model._output_transform = SpeechRecognitionOutputTransform()
     model.eval()
     model.serve()
diff --git a/tests/image/classification/test_model.py b/tests/image/classification/test_model.py
index 11ad1cb29d..84186996fb 100644
--- a/tests/image/classification/test_model.py
+++ b/tests/image/classification/test_model.py
@@ -23,10 +23,7 @@
 from flash.core.classification import ProbabilitiesOutput
 from flash.core.data.io.input import DataKeys
 from flash.core.utilities.imports import _IMAGE_AVAILABLE
-from flash.core.utilities.stages import RunningStage
 from flash.image import ImageClassifier
-from flash.image.classification.data import ImageClassificationInputTransform
-from flash.image.data import ImageDeserializer
 from tests.helpers.utils import _IMAGE_TESTING, _SERVE_TESTING
 
 # ======== Mock functions ========
@@ -138,7 +135,6 @@ def test_jit(tmpdir, jitter, args):
 @mock.patch("flash._IS_TESTING", True)
 def test_serve():
     model = ImageClassifier(2)
-    model._deserializer = ImageDeserializer(transform=ImageClassificationInputTransform(RunningStage.SERVING))
     model.eval()
     model.serve()
diff --git a/tests/image/segmentation/test_model.py b/tests/image/segmentation/test_model.py
index b6630c0e38..b11334582b 100644
--- a/tests/image/segmentation/test_model.py
+++ b/tests/image/segmentation/test_model.py
@@ -24,10 +24,8 @@
 from flash.__main__ import main
 from flash.core.data.io.input import DataKeys
 from flash.core.utilities.imports import _IMAGE_AVAILABLE
-from flash.core.utilities.stages import RunningStage
 from flash.image import SemanticSegmentation
-from flash.image.segmentation.data import SemanticSegmentationData, SemanticSegmentationInputTransform
-from flash.image.segmentation.input import SemanticSegmentationDeserializer
+from flash.image.segmentation.data import SemanticSegmentationData
 from tests.helpers.utils import _IMAGE_TESTING, _SERVE_TESTING
 
 # ======== Mock functions ========
@@ -149,10 +147,6 @@ def test_jit(tmpdir, jitter, args):
 @mock.patch("flash._IS_TESTING", True)
 def test_serve():
     model = SemanticSegmentation(2)
-
-    model._deserializer = SemanticSegmentationDeserializer(
-        transform=SemanticSegmentationInputTransform(RunningStage.SERVING)
-    )
     model.eval()
     model.serve()
diff --git a/tests/tabular/classification/test_model.py b/tests/tabular/classification/test_model.py
index 9a73dd1476..a0af90c924 100644
--- a/tests/tabular/classification/test_model.py
+++ b/tests/tabular/classification/test_model.py
@@ -19,13 +19,11 @@
 import torch
 from pytorch_lightning import Trainer
 
-from flash import InputTransform, RunningStage
 from flash.__main__ import main
 from flash.core.data.io.input import DataKeys
 from flash.core.utilities.imports import _TABULAR_AVAILABLE
 from flash.tabular.classification.data import TabularClassificationData
 from flash.tabular.classification.model import TabularClassifier
-from flash.tabular.input import TabularDeserializer
 from tests.helpers.utils import _SERVE_TESTING, _TABULAR_TESTING
 
 # ======== Mock functions ========
@@ -110,11 +108,8 @@ def test_serve():
         batch_size=1,
     )
     model = TabularClassifier.from_data(datamodule)
-
-    model._deserializer = TabularDeserializer(transform=InputTransform(RunningStage.SERVING))
-    model._deserializer._state = datamodule.train_dataset._state
     model.eval()
-    model.serve()
+    model.serve(parameters=datamodule.parameters)
 
 
 @pytest.mark.skipif(_TABULAR_AVAILABLE, reason="tabular libraries are installed.")
diff --git a/tests/text/classification/test_model.py b/tests/text/classification/test_model.py
index 148fc5f514..240b99d2f4 100644
--- a/tests/text/classification/test_model.py
+++ b/tests/text/classification/test_model.py
@@ -18,13 +18,11 @@
 import pytest
 import torch
 
-from flash import RunningStage, Trainer
+from flash import Trainer
 from flash.__main__ import main
 from flash.core.data.io.input import DataKeys
-from flash.core.integrations.transformers.input_transform import TransformersInputTransform
 from flash.core.utilities.imports import _TEXT_AVAILABLE
 from flash.text import TextClassifier
-from flash.text.input import TextDeserializer
 from tests.helpers.utils import _SERVE_TESTING, _TEXT_TESTING
 
 # ======== Mock functions ========
@@ -78,8 +76,6 @@ def test_jit(tmpdir):
 @mock.patch("flash._IS_TESTING", True)
 def test_serve():
     model = TextClassifier(2, TEST_BACKBONE)
-
-    model._deserializer = TextDeserializer(transform=TransformersInputTransform(RunningStage.SERVING))
     model.eval()
     model.serve()
diff --git a/tests/text/seq2seq/summarization/test_model.py b/tests/text/seq2seq/summarization/test_model.py
index d57843815d..e0c0cd52a9 100644
--- a/tests/text/seq2seq/summarization/test_model.py
+++ b/tests/text/seq2seq/summarization/test_model.py
@@ -18,12 +18,9 @@
 import pytest
 import torch
 
-from flash import DataKeys, RunningStage, Trainer
-from flash.core.integrations.transformers.input_transform import TransformersInputTransform
+from flash import DataKeys, Trainer
 from flash.core.utilities.imports import _TEXT_AVAILABLE
 from flash.text import SummarizationTask
-from flash.text.input import TextDeserializer
-from flash.text.seq2seq.core.output_transform import Seq2SeqOutputTransform
 from tests.helpers.utils import _SERVE_TESTING, _TEXT_TESTING
 
 # ======== Mock functions ========
@@ -79,8 +76,6 @@ def test_jit(tmpdir):
 @mock.patch("flash._IS_TESTING", True)
 def test_serve():
     model = SummarizationTask(TEST_BACKBONE)
-    model._deserializer = TextDeserializer(transform=TransformersInputTransform(RunningStage.SERVING))
-    model._output_transform = Seq2SeqOutputTransform()
     model.eval()
     model.serve()
diff --git a/tests/text/seq2seq/translation/test_model.py b/tests/text/seq2seq/translation/test_model.py
index 05ff3cb5da..732d2ac56b 100644
--- a/tests/text/seq2seq/translation/test_model.py
+++ b/tests/text/seq2seq/translation/test_model.py
@@ -18,12 +18,9 @@
 import pytest
 import torch
 
-from flash import DataKeys, RunningStage, Trainer
-from flash.core.integrations.transformers.input_transform import TransformersInputTransform
+from flash import DataKeys, Trainer
 from flash.core.utilities.imports import _TEXT_AVAILABLE
 from flash.text import TranslationTask
-from flash.text.input import TextDeserializer
-from flash.text.seq2seq.core.output_transform import Seq2SeqOutputTransform
 from tests.helpers.utils import _SERVE_TESTING, _TEXT_TESTING
 
 # ======== Mock functions ========
@@ -79,9 +76,6 @@ def test_jit(tmpdir):
 @mock.patch("flash._IS_TESTING", True)
 def test_serve():
     model = TranslationTask(TEST_BACKBONE)
-    model._deserializer = TextDeserializer(transform=TransformersInputTransform(RunningStage.SERVING))
-    model._output_transform = Seq2SeqOutputTransform()
-
     model.eval()
     model.serve()
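
Taken together, the test updates show the post-PR recipe for making a custom task servable: override `serve` with task-specific defaults instead of assigning `model._deserializer`. A hedged sketch of such an override (`MyServeInput` and `MyTask` are placeholder names, not part of this patch):

```python
from typing import Any, Dict, Optional, Type

from flash.core.data.input_transform import InputTransform
from flash.core.data.io.input import ServeInput
from flash.core.model import Task
from flash.core.serve import Composition
from flash.core.utilities.imports import requires
from flash.core.utilities.types import INPUT_TRANSFORM_TYPE


class MyServeInput(ServeInput):
    def serve_load_sample(self, data: str) -> Any:
        # Turn one raw request string into one sample, as TabularDeserializer does above.
        return data


class MyTask(Task):
    @requires("serve")
    def serve(
        self,
        host: str = "127.0.0.1",
        port: int = 8000,
        sanity_check: bool = True,
        input_cls: Optional[Type[ServeInput]] = MyServeInput,
        transform: INPUT_TRANSFORM_TYPE = InputTransform,
        transform_kwargs: Optional[Dict] = None,
    ) -> Composition:
        return super().serve(host, port, sanity_check, input_cls, transform, transform_kwargs)
```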