diff --git a/CHANGELOG.md b/CHANGELOG.md index 6852f83f3e..b57b450719 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Changed classes named `*Serializer` and properties / variables named `serializer` to be `*Output` and `output` respectively ([#927](https://github.com/PyTorchLightning/lightning-flash/pull/927)) +- Changed `Postprocess` to `OutputTransform` ([#942](https://github.com/PyTorchLightning/lightning-flash/pull/942)) + ### Deprecated - Deprecated `flash.core.data.process.Serializer` in favour of `flash.core.data.io.output.Output` ([#927](https://github.com/PyTorchLightning/lightning-flash/pull/927)) diff --git a/docs/source/api/audio.rst b/docs/source/api/audio.rst index ae6455c6d8..a5aec6cacc 100644 --- a/docs/source/api/audio.rst +++ b/docs/source/api/audio.rst @@ -33,7 +33,7 @@ __________________ speech_recognition.data.SpeechRecognitionPreprocess speech_recognition.data.SpeechRecognitionBackboneState - speech_recognition.data.SpeechRecognitionPostprocess + speech_recognition.data.SpeechRecognitionOutputTransform speech_recognition.data.SpeechRecognitionCSVDataSource speech_recognition.data.SpeechRecognitionJSONDataSource speech_recognition.data.BaseSpeechRecognition diff --git a/docs/source/api/data.rst b/docs/source/api/data.rst index 4e46fd1434..08a59a4d30 100644 --- a/docs/source/api/data.rst +++ b/docs/source/api/data.rst @@ -122,7 +122,7 @@ _______________________ ~flash.core.data.process.DefaultPreprocess ~flash.core.data.process.DeserializerMapping ~flash.core.data.process.Deserializer - ~flash.core.data.process.Postprocess + ~flash.core.data.io.output_transform.OutputTransform ~flash.core.data.process.Preprocess flash.core.data.properties diff --git a/docs/source/api/flash.rst b/docs/source/api/flash.rst index bd087b64d7..f61471eeef 100644 --- a/docs/source/api/flash.rst +++ b/docs/source/api/flash.rst @@ -11,7 +11,7 @@ flash ~flash.core.data.data_module.DataModule ~flash.core.data.callback.FlashCallback ~flash.core.data.process.Preprocess - ~flash.core.data.process.Postprocess + ~flash.core.data.io.output_transform.OutputTransform ~flash.core.data.io.output.Output ~flash.core.model.Task ~flash.core.trainer.Trainer diff --git a/docs/source/api/image.rst b/docs/source/api/image.rst index 84e351a74c..3cedc69058 100644 --- a/docs/source/api/image.rst +++ b/docs/source/api/image.rst @@ -101,7 +101,7 @@ ____________ segmentation.data.SemanticSegmentationPathsDataSource segmentation.data.SemanticSegmentationFiftyOneDataSource segmentation.data.SemanticSegmentationDeserializer - segmentation.model.SemanticSegmentationPostprocess + segmentation.model.SemanticSegmentationOutputTransform segmentation.output.FiftyOneSegmentationLabels segmentation.output.SegmentationLabels diff --git a/docs/source/api/tabular.rst b/docs/source/api/tabular.rst index 45fbca2fb3..1b8b8add8b 100644 --- a/docs/source/api/tabular.rst +++ b/docs/source/api/tabular.rst @@ -59,4 +59,4 @@ __________________ ~data.TabularCSVDataSource ~data.TabularDeserializer ~data.TabularPreprocess - ~data.TabularPostprocess + ~data.TabularOutputTransform diff --git a/docs/source/api/text.rst b/docs/source/api/text.rst index 5fd9941e43..8e088bddb0 100644 --- a/docs/source/api/text.rst +++ b/docs/source/api/text.rst @@ -20,7 +20,7 @@ ______________ ~classification.model.TextClassifier ~classification.data.TextClassificationData - classification.data.TextClassificationPostprocess + 
classification.data.TextClassificationOutputTransform classification.data.TextClassificationPreprocess classification.data.TextDeserializer classification.data.TextDataSource @@ -48,7 +48,7 @@ __________________ question_answering.data.QuestionAnsweringDictionaryDataSource question_answering.data.QuestionAnsweringFileDataSource question_answering.data.QuestionAnsweringJSONDataSource - question_answering.data.QuestionAnsweringPostprocess + question_answering.data.QuestionAnsweringOutputTransform question_answering.data.QuestionAnsweringPreprocess question_answering.data.SQuADDataSource @@ -96,7 +96,7 @@ _______________ seq2seq.core.data.Seq2SeqDataSource seq2seq.core.data.Seq2SeqFileDataSource seq2seq.core.data.Seq2SeqJSONDataSource - seq2seq.core.data.Seq2SeqPostprocess + seq2seq.core.data.Seq2SeqOutputTransform seq2seq.core.data.Seq2SeqPreprocess seq2seq.core.data.Seq2SeqSentencesDataSource seq2seq.core.metrics.BLEUScore diff --git a/docs/source/general/data.rst b/docs/source/general/data.rst index c5d51d1f96..ff2ca87fb8 100644 --- a/docs/source/general/data.rst +++ b/docs/source/general/data.rst @@ -26,7 +26,7 @@ Here are common terms you need to be familiar with: * - :class:`~flash.core.data.data_module.DataModule` - The :class:`~flash.core.data.data_module.DataModule` contains the datasets, transforms and dataloaders. * - :class:`~flash.core.data.data_pipeline.DataPipeline` - - The :class:`~flash.core.data.data_pipeline.DataPipeline` is Flash internal object to manage :class:`~flash.core.data.Deserializer`, :class:`~flash.core.data.data_source.DataSource`, :class:`~flash.core.data.process.Preprocess`, :class:`~flash.core.data.process.Postprocess`, and :class:`~flash.core.data.io.output.Output` objects. + - The :class:`~flash.core.data.data_pipeline.DataPipeline` is a Flash internal object used to manage :class:`~flash.core.data.Deserializer`, :class:`~flash.core.data.data_source.DataSource`, :class:`~flash.core.data.process.Preprocess`, :class:`~flash.core.data.io.output_transform.OutputTransform`, and :class:`~flash.core.data.io.output.Output` objects. * - :class:`~flash.core.data.data_source.DataSource` - The :class:`~flash.core.data.data_source.DataSource` provides :meth:`~flash.core.data.data_source.DataSource.load_data` and :meth:`~flash.core.data.data_source.DataSource.load_sample` hooks for creating data sets from metadata (such as folder names). * - :class:`~flash.core.data.process.Preprocess` @@ -34,11 +34,11 @@ Here are common terms you need to be familiar with: These hooks (such as :meth:`~flash.core.data.process.Preprocess.pre_tensor_transform`) enable transformations to be applied to your data at every point along the pipeline (including on the device). The :class:`~flash.core.data.data_pipeline.DataPipeline` contains a system to call the right hooks when needed. The :class:`~flash.core.data.process.Preprocess` hooks can be either overridden directly or provided as a dictionary of transforms (mapping hook name to callable transform). - * - :class:`~flash.core.data.process.Postprocess` - - The :class:`~flash.core.data.process.Postprocess` provides a simple hook-based API to encapsulate your post-processing logic. The :class:`~flash.core.data.process.Postprocess` hooks cover from model outputs to predictions export. + * - :class:`~flash.core.data.io.output_transform.OutputTransform` + - The :class:`~flash.core.data.io.output_transform.OutputTransform` provides a simple hook-based API to encapsulate your post-processing logic.
+ The :class:`~flash.core.data.io.output_transform.OutputTransform` hooks cover everything from model outputs to prediction export. * - :class:`~flash.core.data.io.output.Output` - - The :class:`~flash.core.data.io.output.Output` provides a single :meth:`~flash.core.data.io.output.Output.serialize` method that is used to convert model outputs (after the :class:`~flash.core.data.process.Postprocess`) to the desired output format during prediction. + - The :class:`~flash.core.data.io.output.Output` provides a single :meth:`~flash.core.data.io.output.Output.serialize` method that is used to convert model outputs (after the :class:`~flash.core.data.io.output_transform.OutputTransform`) to the desired output format during prediction. ******************************************* @@ -58,8 +58,8 @@ However, after model training, it requires a lot of engineering overhead to make Usually, extra processing logic should be added to bridge the gap between training data and raw data. The :class:`~flash.core.data.data_source.DataSource` class can be used to generate data sets from multiple sources (e.g. folders, numpy, etc.), that can then all be transformed in the same way. -The :class:`~flash.core.data.process.Preprocess` and :class:`~flash.core.data.process.Postprocess` classes can be used to manage the preprocessing and postprocessing transforms. -The :class:`~flash.core.data.io.output.Output` class provides the logic for converting :class:`~flash.core.data.process.Postprocess` outputs to the desired predict format (e.g. classes, labels, probabilities, etc.). +The :class:`~flash.core.data.process.Preprocess` and :class:`~flash.core.data.io.output_transform.OutputTransform` classes can be used to manage the preprocessing and postprocessing transforms. +The :class:`~flash.core.data.io.output.Output` class provides the logic for converting :class:`~flash.core.data.io.output_transform.OutputTransform` outputs to the desired predict format (e.g. classes, labels, probabilities, etc.). By providing a series of hooks that can be overridden with custom data processing logic (or just targeted with transforms), Flash gives the user much more granular control over their data processing flow. Here are the primary advantages: @@ -72,7 +72,7 @@ To change the processing behavior only on specific stages for a given hook, -you can prefix each of the :class:`~flash.core.data.process.Preprocess` and :class:`~flash.core.data.process.Postprocess` +you can prefix each of the :class:`~flash.core.data.process.Preprocess` and :class:`~flash.core.data.io.output_transform.OutputTransform` hooks by adding ``train``, ``val``, ``test`` or ``predict``. Check out :class:`~flash.core.data.process.Preprocess` for some examples. @@ -383,17 +383,17 @@ Example:: predictions = lightning_module(data) -Postprocess and Output +OutputTransform and Output __________________________ Once the predictions have been generated by the Flash :class:`~flash.core.model.Task`, the Flash -:class:`~flash.core.data.data_pipeline.DataPipeline` will execute the :class:`~flash.core.data.process.Postprocess` hooks and the +:class:`~flash.core.data.data_pipeline.DataPipeline` will execute the :class:`~flash.core.data.io.output_transform.OutputTransform` hooks and the :class:`~flash.core.data.io.output.Output` behind the scenes. -First, the :meth:`~flash.core.data.process.Postprocess.per_batch_transform` hooks will be applied on the batch predictions. -Then, the :meth:`~flash.core.data.process.Postprocess.uncollate` will split the batch into individual predictions.
-Next, the :meth:`~flash.core.data.process.Postprocess.per_sample_transform` will be applied on each prediction. +First, the :meth:`~flash.core.data.io.output_transform.OutputTransform.per_batch_transform` hooks will be applied to the batch predictions. +Then, the :meth:`~flash.core.data.io.output_transform.OutputTransform.uncollate` will split the batch into individual predictions. +Next, the :meth:`~flash.core.data.io.output_transform.OutputTransform.per_sample_transform` will be applied to each prediction. Finally, the :meth:`~flash.core.data.io.output.Output.serialize` method will be called to serialize the predictions. .. note:: The transform can be applied either on device or ``CPU``. @@ -402,7 +402,7 @@ Here is the pseudo-code: Example:: - # This will be wrapped into a :class:`~flash.core.data.batch._Postprocessor` + # This will be wrapped into a :class:`~flash.core.data.io.output_transform._OutputTransformProcessor` def uncollate_fn(batch: Any) -> Any: batch = per_batch_transform(batch) diff --git a/docs/source/template/data.rst b/docs/source/template/data.rst index 8b8881ce24..1cd7fcd10c 100644 --- a/docs/source/template/data.rst +++ b/docs/source/template/data.rst @@ -11,7 +11,7 @@ Inside `data.py str: ) -class _Postprocessor(torch.nn.Module): - """This class is used to encapsultate the following functions of a Postprocess Object: - - Inside main process: - per_batch_transform: Function to transform a batch - per_sample_transform: Function to transform an individual sample - uncollate_fn: Function to split a batch into samples - per_sample_transform: Function to transform an individual sample - save_fn: Function to save all data - save_per_sample: Function to save an individual sample - is_serving: Whether the Postprocessor is used in serving mode. - """ - - def __init__( - self, - uncollate_fn: Callable, - per_batch_transform: Callable, - per_sample_transform: Callable, - output: Optional[Callable], - save_fn: Optional[Callable] = None, - save_per_sample: bool = False, - is_serving: bool = False, - ): - super().__init__() - self.uncollate_fn = convert_to_modules(uncollate_fn) - self.per_batch_transform = convert_to_modules(per_batch_transform) - self.per_sample_transform = convert_to_modules(per_sample_transform) - self.output = convert_to_modules(output) - self.save_fn = convert_to_modules(save_fn) - self.save_per_sample = convert_to_modules(save_per_sample) - self.is_serving = is_serving - - @staticmethod - def _extract_metadata(batch: Any) -> Tuple[Any, Optional[Any]]: - metadata = None - if isinstance(batch, Mapping) and DefaultDataKeys.METADATA in batch: - metadata = batch.pop(DefaultDataKeys.METADATA, None) - return batch, metadata - - def forward(self, batch: Sequence[Any]): - batch, metadata = self._extract_metadata(batch) - uncollated = self.uncollate_fn(self.per_batch_transform(batch)) - if metadata: - for sample, sample_metadata in zip(uncollated, metadata): - sample[DefaultDataKeys.METADATA] = sample_metadata - - final_preds = [self.per_sample_transform(sample) for sample in uncollated] - - if self.output is not None: - final_preds = [self.output(sample) for sample in final_preds] - - if isinstance(uncollated, Tensor) and isinstance(final_preds[0], Tensor): - final_preds = torch.stack(final_preds) - else: - final_preds = type(final_preds)(final_preds) - - if self.save_fn: - if self.save_per_sample: - for pred in final_preds: - self.save_fn(pred) - else: - self.save_fn(final_preds) - return final_preds - - def __str__(self) -> str: - return ( - "_Postprocessor:\n" -
f"\t(per_batch_transform): {str(self.per_batch_transform)}\n" - f"\t(uncollate_fn): {str(self.uncollate_fn)}\n" - f"\t(per_sample_transform): {str(self.per_sample_transform)}\n" - f"\t(output): {str(self.output)}" - ) - - def default_uncollate(batch: Any): """ This function is used to uncollate a batch into samples. diff --git a/flash/core/data/data_module.py b/flash/core/data/data_module.py index db9e00aff7..3d89d4bdef 100644 --- a/flash/core/data/data_module.py +++ b/flash/core/data/data_module.py @@ -39,8 +39,9 @@ from flash.core.data.auto_dataset import BaseAutoDataset, IterableAutoDataset from flash.core.data.base_viz import BaseVisualization from flash.core.data.callback import BaseDataFetcher -from flash.core.data.data_pipeline import DataPipeline, DefaultPreprocess, Postprocess, Preprocess +from flash.core.data.data_pipeline import DataPipeline, DefaultPreprocess, Preprocess from flash.core.data.data_source import DataSource, DefaultDataSources +from flash.core.data.io.output_transform import OutputTransform from flash.core.data.splits import SplitDataset from flash.core.data.utils import _STAGES_PREFIX from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, requires @@ -55,7 +56,8 @@ class DataModule(pl.LightningDataModule): """A basic DataModule class for all Flash tasks. This class includes references to a :class:`~flash.core.data.data_source.DataSource`, :class:`~flash.core.data.process.Preprocess`, - :class:`~flash.core.data.process.Postprocess`, and a :class:`~flash.core.data.callback.BaseDataFetcher`. + :class:`~flash.core.data.io.output_transform.OutputTransform`, and a + :class:`~flash.core.data.callback.BaseDataFetcher`. Args: train_dataset: Dataset for training. Defaults to None. @@ -66,9 +68,9 @@ class DataModule(pl.LightningDataModule): preprocess: The :class:`~flash.core.data.process.Preprocess` to use when constructing the :class:`~flash.core.data.data_pipeline.DataPipeline`. If ``None``, a :class:`~flash.core.data.process.DefaultPreprocess` will be used. - postprocess: The :class:`~flash.core.data.process.Postprocess` to use when constructing the + output_transform: The :class:`~flash.core.data.io.output_transform.OutputTransform` to use when constructing the :class:`~flash.core.data.data_pipeline.DataPipeline`. If ``None``, a plain - :class:`~flash.core.data.process.Postprocess` will be used. + :class:`~flash.core.data.io.output_transform.OutputTransform` will be used. data_fetcher: The :class:`~flash.core.data.callback.BaseDataFetcher` to attach to the :class:`~flash.core.data.process.Preprocess`. If ``None``, the output from :meth:`~flash.core.data.data_module.DataModule.configure_data_fetcher` will be used. 
@@ -83,7 +85,7 @@ class DataModule(pl.LightningDataModule): """ preprocess_cls = DefaultPreprocess - postprocess_cls = Postprocess + output_transform_cls = OutputTransform def __init__( self, @@ -93,7 +95,7 @@ def __init__( predict_dataset: Optional[Dataset] = None, data_source: Optional[DataSource] = None, preprocess: Optional[Preprocess] = None, - postprocess: Optional[Postprocess] = None, + output_transform: Optional[OutputTransform] = None, data_fetcher: Optional[BaseDataFetcher] = None, val_split: Optional[float] = None, batch_size: int = 4, @@ -108,7 +110,7 @@ def __init__( self._data_source: DataSource = data_source self._preprocess: Optional[Preprocess] = preprocess - self._postprocess: Optional[Postprocess] = postprocess + self._output_transform: Optional[OutputTransform] = output_transform self._viz: Optional[BaseVisualization] = None self._data_fetcher: Optional[BaseDataFetcher] = data_fetcher or self.configure_data_fetcher() @@ -433,15 +435,16 @@ def preprocess(self) -> Preprocess: return self._preprocess or self.preprocess_cls() @property - def postprocess(self) -> Postprocess: - """Property that returns the postprocessing class used on the input data.""" - return self._postprocess or self.postprocess_cls() + def output_transform(self) -> OutputTransform: + """Property that returns the :class:`~flash.core.data.io.output_transform.OutputTransform` used to + transform the model outputs.""" + return self._output_transform or self.output_transform_cls() @property def data_pipeline(self) -> DataPipeline: """Property that returns the full data pipeline including the data source, preprocessing and postprocessing.""" - return DataPipeline(self.data_source, self.preprocess, self.postprocess) + return DataPipeline(self.data_source, self.preprocess, self.output_transform) def available_data_sources(self) -> Sequence[str]: """Get the list of available data source names for use with this diff --git a/flash/core/data/data_pipeline.py b/flash/core/data/data_pipeline.py index 2dcdcb7294..bf21944adf 100644 --- a/flash/core/data/data_pipeline.py +++ b/flash/core/data/data_pipeline.py @@ -24,12 +24,13 @@ import flash from flash.core.data.auto_dataset import IterableAutoDataset -from flash.core.data.batch import _DeserializeProcessor, _Postprocessor, _Preprocessor, _Sequential +from flash.core.data.batch import _DeserializeProcessor, _Preprocessor, _Sequential from flash.core.data.data_source import DataSource from flash.core.data.io.output import _OutputProcessor, Output -from flash.core.data.process import DefaultPreprocess, Deserializer, Postprocess, Preprocess +from flash.core.data.io.output_transform import _OutputTransformProcessor, OutputTransform +from flash.core.data.process import DefaultPreprocess, Deserializer, Preprocess from flash.core.data.properties import ProcessState -from flash.core.data.utils import _POSTPROCESS_FUNCS, _PREPROCESS_FUNCS, _STAGES_PREFIX +from flash.core.data.utils import _OUTPUT_TRANSFORM_FUNCS, _PREPROCESS_FUNCS, _STAGES_PREFIX from flash.core.utilities.imports import _PL_GREATER_EQUAL_1_4_3, _PL_GREATER_EQUAL_1_5_0 from flash.core.utilities.stages import _RUNNING_STAGE_MAPPING, RunningStage @@ -76,53 +77,38 @@ def __str__(self) -> str: class DataPipeline: """ DataPipeline holds the engineering logic to connect - :class:`~flash.core.data.process.Preprocess` and/or :class:`~flash.core.data.process.Postprocess` objects to - the ``DataModule``, Flash ``Task`` and ``Trainer``.
- - Example:: - - class CustomPreprocess(Preprocess): - pass - - class CustomPostprocess(Postprocess): - pass - - custom_data_pipeline = DataPipeline(CustomPreprocess(), CustomPostprocess()) - - # And it can attached to both the datamodule and model. - - datamodule.data_pipeline = custom_data_pipeline - model.data_pipeline = custom_data_pipeline + :class:`~flash.core.data.process.Preprocess` and/or :class:`~flash.core.data.io.output_transform.OutputTransform` + objects to the ``DataModule``, Flash ``Task`` and ``Trainer``. """ PREPROCESS_FUNCS: Set[str] = _PREPROCESS_FUNCS - POSTPROCESS_FUNCS: Set[str] = _POSTPROCESS_FUNCS + OUTPUT_TRANSFORM_FUNCS: Set[str] = _OUTPUT_TRANSFORM_FUNCS def __init__( self, data_source: Optional[DataSource] = None, preprocess: Optional[Preprocess] = None, - postprocess: Optional[Postprocess] = None, + output_transform: Optional[OutputTransform] = None, deserializer: Optional[Deserializer] = None, output: Optional[Output] = None, ) -> None: self.data_source = data_source self._preprocess_pipeline = preprocess or DefaultPreprocess() - self._postprocess_pipeline = postprocess or Postprocess() + self._output_transform = output_transform or OutputTransform() self._output = output or Output() self._deserializer = deserializer or Deserializer() self._running_stage = None def initialize(self, data_pipeline_state: Optional[DataPipelineState] = None) -> DataPipelineState: """Creates the :class:`.DataPipelineState` and gives the reference to the: :class:`.Preprocess`, - :class:`.Postprocess`, and :class:`.Output`. Once this has been called, any attempt to add new state will + :class:`.OutputTransform`, and :class:`.Output`. Once this has been called, any attempt to add new state will give a warning.""" data_pipeline_state = data_pipeline_state or DataPipelineState() if self.data_source is not None: self.data_source.attach_data_pipeline_state(data_pipeline_state) self._preprocess_pipeline.attach_data_pipeline_state(data_pipeline_state) - self._postprocess_pipeline.attach_data_pipeline_state(data_pipeline_state) + self._output_transform.attach_data_pipeline_state(data_pipeline_state) self._output.attach_data_pipeline_state(data_pipeline_state) return data_pipeline_state @@ -179,8 +165,8 @@ def worker_preprocessor( def device_preprocessor(self, running_stage: RunningStage) -> _Preprocessor: return self._create_collate_preprocessors(running_stage)[2] - def postprocessor(self, running_stage: RunningStage, is_serving=False) -> _Postprocessor: - return self._create_uncollate_postprocessors(running_stage, is_serving=is_serving) + def output_transform_processor(self, running_stage: RunningStage, is_serving=False) -> _OutputTransformProcessor: + return self._create_output_transform_processor(running_stage, is_serving=is_serving) def output_processor(self) -> _OutputProcessor: return _OutputProcessor(self._output) @@ -315,13 +301,15 @@ def _model_transfer_to_device_wrapper( return func @staticmethod - def _model_predict_step_wrapper(func: Callable, postprocessor: _Postprocessor, model: "Task") -> Callable: + def _model_predict_step_wrapper( + func: Callable, output_transform_processor: _OutputTransformProcessor, model: "Task" + ) -> Callable: if not isinstance(func, _StageOrchestrator): _original = func func = _StageOrchestrator(func, model) func._original = _original - func.register_additional_stage(RunningStage.PREDICTING, postprocessor) + func.register_additional_stage(RunningStage.PREDICTING, output_transform_processor) return func @@ -448,50 +436,50 @@ def 
_attach_preprocess_to_model( model.transfer_batch_to_device, device_collate_fn, model, stage ) - def _create_uncollate_postprocessors( + def _create_output_transform_processor( self, stage: RunningStage, is_serving: bool = False, - ) -> _Postprocessor: + ) -> _OutputTransformProcessor: save_per_sample = None save_fn = None - postprocess: Postprocess = self._postprocess_pipeline + output_transform: OutputTransform = self._output_transform func_names: Dict[str, str] = { - k: self._resolve_function_hierarchy(k, postprocess, stage, object_type=Postprocess) - for k in self.POSTPROCESS_FUNCS + k: self._resolve_function_hierarchy(k, output_transform, stage, object_type=OutputTransform) + for k in self.OUTPUT_TRANSFORM_FUNCS } # since postprocessing is exclusive for prediction, we don't have to check the resolution hierarchy here. - if postprocess._save_path: + if output_transform._save_path: save_per_sample: bool = self._is_overriden_recursive( - "save_sample", postprocess, Postprocess, prefix=_STAGES_PREFIX[stage] + "save_sample", output_transform, OutputTransform, prefix=_STAGES_PREFIX[stage] ) if save_per_sample: - save_per_sample: Callable = getattr(postprocess, func_names["save_sample"]) + save_per_sample: Callable = getattr(output_transform, func_names["save_sample"]) else: - save_fn: Callable = getattr(postprocess, func_names["save_data"]) + save_fn: Callable = getattr(output_transform, func_names["save_data"]) - return _Postprocessor( - getattr(postprocess, func_names["uncollate"]), - getattr(postprocess, func_names["per_batch_transform"]), - getattr(postprocess, func_names["per_sample_transform"]), + return _OutputTransformProcessor( + getattr(output_transform, func_names["uncollate"]), + getattr(output_transform, func_names["per_batch_transform"]), + getattr(output_transform, func_names["per_sample_transform"]), output=None if is_serving else self._output, save_fn=save_fn, save_per_sample=save_per_sample, is_serving=is_serving, ) - def _attach_postprocess_to_model( + def _attach_output_transform_to_model( self, model: "Task", stage: RunningStage, is_serving: bool = False, ) -> "Task": model.predict_step = self._model_predict_step_wrapper( - model.predict_step, self._create_uncollate_postprocessors(stage, is_serving=is_serving), model + model.predict_step, self._create_output_transform_processor(stage, is_serving=is_serving), model ) return model @@ -505,13 +493,13 @@ def _attach_to_model( self._attach_preprocess_to_model(model, stage) if not stage or stage == RunningStage.PREDICTING: - self._attach_postprocess_to_model(model, RunningStage.PREDICTING, is_serving=is_serving) + self._attach_output_transform_to_model(model, RunningStage.PREDICTING, is_serving=is_serving) def _detach_from_model(self, model: "Task", stage: Optional[RunningStage] = None): self._detach_preprocessing_from_model(model, stage) if not stage or stage == RunningStage.PREDICTING: - self._detach_postprocess_from_model(model) + self._detach_output_transform_from_model(model) def _detach_preprocessing_from_model(self, model: "Task", stage: Optional[RunningStage] = None): if not stage: @@ -578,7 +566,7 @@ def _detach_preprocessing_from_model(self, model: "Task", stage: Optional[Runnin self._set_loader(model, whole_attr_name, dataloader) @staticmethod - def _detach_postprocess_from_model(model: "Task"): + def _detach_output_transform_from_model(model: "Task"): if hasattr(model.predict_step, "_original"): # don't delete the predict_step here since we don't know @@ -588,7 +576,7 @@ def _detach_postprocess_from_model(model: 
"Task"): def __str__(self) -> str: data_source: DataSource = self.data_source preprocess: Preprocess = self._preprocess_pipeline - postprocess: Postprocess = self._postprocess_pipeline + output_transform: OutputTransform = self._output_transform output: Output = self._output deserializer: Deserializer = self._deserializer return ( @@ -596,7 +584,7 @@ def __str__(self) -> str: f"data_source={str(data_source)}, " f"deserializer={deserializer}, " f"preprocess={preprocess}, " - f"postprocess={postprocess}, " + f"output_transform={output_transform}, " f"output={output})" ) diff --git a/flash/core/data/io/output.py b/flash/core/data/io/output.py index ce2cd9ef4b..816cb01213 100644 --- a/flash/core/data/io/output.py +++ b/flash/core/data/io/output.py @@ -28,7 +28,7 @@ def transform(sample: Any) -> Any: """Convert the given sample into the desired output format. Args: - sample: The output from the :class:`.Postprocess`. + sample: The output from the :class:`.OutputTransform`. Returns: The converted output. diff --git a/flash/core/data/io/output_transform.py b/flash/core/data/io/output_transform.py new file mode 100644 index 0000000000..44aa4f926a --- /dev/null +++ b/flash/core/data/io/output_transform.py @@ -0,0 +1,153 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +from typing import Any, Callable, Mapping, Optional, Sequence, Tuple + +import torch +from torch import Tensor + +from flash.core.data.batch import default_uncollate +from flash.core.data.data_source import DefaultDataKeys +from flash.core.data.properties import Properties +from flash.core.data.utils import convert_to_modules + + +class OutputTransform(Properties): + """The :class:`~flash.core.data.io.output_transform.OutputTransform` encapsulates all the data processing logic + that should run after the model.""" + + def __init__(self, save_path: Optional[str] = None): + super().__init__() + self._saved_samples = 0 + self._save_path = save_path + + @staticmethod + def per_batch_transform(batch: Any) -> Any: + """Transforms to apply on a whole batch before uncollation to individual samples. + + Can involve both CPU and Device transforms as this is not applied in separate workers. + """ + return batch + + @staticmethod + def per_sample_transform(sample: Any) -> Any: + """Transforms to apply to a single sample after splitting up the batch. + + Can involve both CPU and Device transforms as this is not applied in separate workers. + """ + return sample + + @staticmethod + def uncollate(batch: Any) -> Any: + """Uncollates a batch into single samples. + + Tries to preserve the type whereever possible. + """ + return default_uncollate(batch) + + @staticmethod + def save_data(data: Any, path: str) -> None: + """Saves all data together to a single path.""" + torch.save(data, path) + + @staticmethod + def save_sample(sample: Any, path: str) -> None: + """Saves each sample individually to a given path.""" + torch.save(sample, path) + + # TODO: Are those needed ? 
+ def format_sample_save_path(self, path: str) -> str: + path = os.path.join(path, f"sample_{self._saved_samples}.ptl") + self._saved_samples += 1 + return path + + def _save_data(self, data: Any) -> None: + self.save_data(data, self._save_path) + + def _save_sample(self, sample: Any) -> None: + self.save_sample(sample, self.format_sample_save_path(self._save_path)) + + +class _OutputTransformProcessor(torch.nn.Module): + """This class is used to encapsulate the following functions of an OutputTransform object: + + Inside main process: + per_batch_transform: Function to transform a batch + uncollate_fn: Function to split a batch into samples + per_sample_transform: Function to transform an individual sample + save_fn: Function to save all data + save_per_sample: Function to save an individual sample + is_serving: Whether the output transform processor is used in serving mode. + """ + + def __init__( + self, + uncollate_fn: Callable, + per_batch_transform: Callable, + per_sample_transform: Callable, + output: Optional[Callable], + save_fn: Optional[Callable] = None, + save_per_sample: bool = False, + is_serving: bool = False, + ): + super().__init__() + self.uncollate_fn = convert_to_modules(uncollate_fn) + self.per_batch_transform = convert_to_modules(per_batch_transform) + self.per_sample_transform = convert_to_modules(per_sample_transform) + self.output = convert_to_modules(output) + self.save_fn = convert_to_modules(save_fn) + self.save_per_sample = convert_to_modules(save_per_sample) + self.is_serving = is_serving + + @staticmethod + def _extract_metadata(batch: Any) -> Tuple[Any, Optional[Any]]: + metadata = None + if isinstance(batch, Mapping) and DefaultDataKeys.METADATA in batch: + metadata = batch.pop(DefaultDataKeys.METADATA, None) + return batch, metadata + + def forward(self, batch: Sequence[Any]): + batch, metadata = self._extract_metadata(batch) + uncollated = self.uncollate_fn(self.per_batch_transform(batch)) + if metadata: + for sample, sample_metadata in zip(uncollated, metadata): + sample[DefaultDataKeys.METADATA] = sample_metadata + + final_preds = [self.per_sample_transform(sample) for sample in uncollated] + + if self.output is not None: + final_preds = [self.output(sample) for sample in final_preds] + + if isinstance(uncollated, Tensor) and isinstance(final_preds[0], Tensor): + final_preds = torch.stack(final_preds) + else: + final_preds = type(final_preds)(final_preds) + + if self.save_fn: + if self.save_per_sample: + for pred in final_preds: + self.save_fn(pred) + else: + self.save_fn(final_preds) + return final_preds + + def __str__(self) -> str: + return ( + "_OutputTransformProcessor:\n" + f"\t(per_batch_transform): {str(self.per_batch_transform)}\n" + f"\t(uncollate_fn): {str(self.uncollate_fn)}\n" + f"\t(per_sample_transform): {str(self.per_sample_transform)}\n" + f"\t(output): {str(self.output)}" + ) diff --git a/flash/core/data/new_data_module.py b/flash/core/data/new_data_module.py index f9aa259df9..ed309646f2 100644 --- a/flash/core/data/new_data_module.py +++ b/flash/core/data/new_data_module.py @@ -26,9 +26,10 @@ from flash.core.data.base_viz import BaseVisualization from flash.core.data.callback import BaseDataFetcher from flash.core.data.data_module import DataModule -from flash.core.data.data_pipeline import DefaultPreprocess, Postprocess +from flash.core.data.data_pipeline import DefaultPreprocess from flash.core.data.datasets import BaseDataset from flash.core.data.input_transform import
INPUT_TRANSFORM_TYPE, InputTransform +from flash.core.data.io.output_transform import OutputTransform from flash.core.registry import FlashRegistry from flash.core.utilities.stages import RunningStage @@ -56,7 +57,7 @@ class DataModule(DataModule): """ preprocess_cls = DefaultPreprocess - postprocess_cls = Postprocess + output_transform_cls = OutputTransform flash_datasets_registry = FlashRegistry("datasets") def __init__( @@ -80,7 +81,7 @@ def __init__( if flash._IS_TESTING and torch.cuda.is_available(): batch_size = 16 - self._postprocess: Optional[Postprocess] = None + self._output_transform: Optional[OutputTransform] = None self._viz: Optional[BaseVisualization] = None self._data_fetcher: Optional[BaseDataFetcher] = data_fetcher or self.configure_data_fetcher() diff --git a/flash/core/data/process.py b/flash/core/data/process.py index a9b56d312b..4396171923 100644 --- a/flash/core/data/process.py +++ b/flash/core/data/process.py @@ -13,7 +13,6 @@ # limitations under the License. import functools import inspect -import os from abc import ABC, abstractclassmethod, abstractmethod from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Union @@ -25,7 +24,6 @@ from torch.utils.data._utils.collate import default_collate import flash -from flash.core.data.batch import default_uncollate from flash.core.data.callback import FlashCallback from flash.core.data.data_source import DatasetDataSource, DataSource, DefaultDataKeys, DefaultDataSources from flash.core.data.io.output import Output @@ -530,62 +528,6 @@ def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool): return cls(**state_dict) -class Postprocess(Properties): - """The :class:`~flash.core.data.process.Postprocess` encapsulates all the data processing logic that should run - after the model.""" - - def __init__(self, save_path: Optional[str] = None): - super().__init__() - self._saved_samples = 0 - self._save_path = save_path - - @staticmethod - def per_batch_transform(batch: Any) -> Any: - """Transforms to apply on a whole batch before uncollation to individual samples. - - Can involve both CPU and Device transforms as this is not applied in separate workers. - """ - return batch - - @staticmethod - def per_sample_transform(sample: Any) -> Any: - """Transforms to apply to a single sample after splitting up the batch. - - Can involve both CPU and Device transforms as this is not applied in separate workers. - """ - return sample - - @staticmethod - def uncollate(batch: Any) -> Any: - """Uncollates a batch into single samples. - - Tries to preserve the type whereever possible. - """ - return default_uncollate(batch) - - @staticmethod - def save_data(data: Any, path: str) -> None: - """Saves all data together to a single path.""" - torch.save(data, path) - - @staticmethod - def save_sample(sample: Any, path: str) -> None: - """Saves each sample individually to a given path.""" - torch.save(sample, path) - - # TODO: Are those needed ? 
- def format_sample_save_path(self, path: str) -> str: - path = os.path.join(path, f"sample_{self._saved_samples}.ptl") - self._saved_samples += 1 - return path - - def _save_data(self, data: Any) -> None: - self.save_data(data, self._save_path) - - def _save_sample(self, sample: Any) -> None: - self.save_sample(sample, self.format_sample_save_path(self._save_path)) - - class Deserializer(Properties): """Deserializer.""" diff --git a/flash/core/data/utils.py b/flash/core/data/utils.py index 761ea6c6b1..46342a2eb3 100644 --- a/flash/core/data/utils.py +++ b/flash/core/data/utils.py @@ -53,7 +53,7 @@ *_PREPROCESS_FUNCS, } -_POSTPROCESS_FUNCS: Set[str] = { +_OUTPUT_TRANSFORM_FUNCS: Set[str] = { "per_batch_transform", "uncollate", "per_sample_transform", diff --git a/flash/core/model.py b/flash/core/model.py index 405f156fc8..ffe7063ea8 100644 --- a/flash/core/model.py +++ b/flash/core/model.py @@ -40,7 +40,8 @@ from flash.core.data.data_pipeline import DataPipeline, DataPipelineState from flash.core.data.data_source import DataSource from flash.core.data.io.output import Output -from flash.core.data.process import Deserializer, DeserializerMapping, Postprocess, Preprocess +from flash.core.data.io.output_transform import OutputTransform +from flash.core.data.process import Deserializer, DeserializerMapping, Preprocess from flash.core.data.properties import ProcessState from flash.core.optimizers.optimizers import _OPTIMIZERS_REGISTRY from flash.core.optimizers.schedulers import _SCHEDULERS_REGISTRY @@ -57,8 +58,8 @@ METRICS_TYPE, MODEL_TYPE, OPTIMIZER_TYPE, + OUTPUT_TRANSFORM_TYPE, OUTPUT_TYPE, - POSTPROCESS_TYPE, PREPROCESS_TYPE, ) @@ -318,7 +319,8 @@ class Task(DatasetProcessor, ModuleWrapperBase, LightningModule, metaclass=Check deserializer: Either a single :class:`~flash.core.data.process.Deserializer` or a mapping of these to deserialize the input preprocess: :class:`~flash.core.data.process.Preprocess` to use as the default for this task. - postprocess: :class:`~flash.core.data.process.Postprocess` to use as the default for this task. + output_transform: :class:`~flash.core.data.io.output_transform.OutputTransform` to use as the default for this + task. output: The :class:`~flash.core.data.io.output.Output` to use when formatting prediction outputs. 
""" @@ -337,7 +339,7 @@ def __init__( metrics: METRICS_TYPE = None, deserializer: DESERIALIZER_TYPE = None, preprocess: PREPROCESS_TYPE = None, - postprocess: POSTPROCESS_TYPE = None, + output_transform: OUTPUT_TRANSFORM_TYPE = None, output: OUTPUT_TYPE = None, ): super().__init__() @@ -356,7 +358,7 @@ def __init__( self._deserializer: Optional[Deserializer] = None self._preprocess: Optional[Preprocess] = preprocess - self._postprocess: Optional[Postprocess] = postprocess + self._output_transform: Optional[OutputTransform] = output_transform self._output: Optional[Output] = None # Explicitly set the output to call the setter @@ -497,7 +499,7 @@ def predict( x = data_pipeline.device_preprocessor(running_stage)(x) x = x[0] if isinstance(x, list) else x predictions = self.predict_step(x, 0) # batch_idx is always 0 when running with `model.predict` - predictions = data_pipeline.postprocessor(running_stage)(predictions) + predictions = data_pipeline.output_transform_processor(running_stage)(predictions) return predictions def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: @@ -568,29 +570,31 @@ def configure_finetune_callback() -> List[Callback]: def _resolve( old_deserializer: Optional[Deserializer], old_preprocess: Optional[Preprocess], - old_postprocess: Optional[Postprocess], + old_output_transform: Optional[OutputTransform], old_output: Optional[Output], new_deserializer: Optional[Deserializer], new_preprocess: Optional[Preprocess], - new_postprocess: Optional[Postprocess], + new_output_transform: Optional[OutputTransform], new_output: Optional[Output], - ) -> Tuple[Optional[Deserializer], Optional[Preprocess], Optional[Postprocess], Optional[Output]]: - """Resolves the correct :class:`~flash.core.data.process.Preprocess`, :class:`~flash.core.data.process.Postprocess`, and - :class:`~flash.core.data.io.output.Output` to use, choosing ``new_*`` if it is not None or a base class - (:class:`~flash.core.data.process.Preprocess`, :class:`~flash.core.data.process.Postprocess`, or - :class:`~flash.core.data.io.output.Output`) and ``old_*`` otherwise. + ) -> Tuple[Optional[Deserializer], Optional[Preprocess], Optional[OutputTransform], Optional[Output]]: + """Resolves the correct :class:`~flash.core.data.process.Preprocess`, + :class:`~flash.core.data.io.output_transform.OutputTransform`, and :class:`~flash.core.data.io.output.Output` to + use, choosing ``new_*`` if it is not None or a base class (:class:`~flash.core.data.process.Preprocess`, + :class:`~flash.core.data.io.output_transform.OutputTransform`, or :class:`~flash.core.data.io.output.Output`) + and ``old_*`` otherwise. Args: old_preprocess: :class:`~flash.core.data.process.Preprocess` to be overridden. - old_postprocess: :class:`~flash.core.data.process.Postprocess` to be overridden. + old_output_transform: :class:`~flash.core.data.io.output_transform.OutputTransform` to be overridden. old_output: :class:`~flash.core.data.io.output.Output` to be overridden. new_preprocess: :class:`~flash.core.data.process.Preprocess` to override with. - new_postprocess: :class:`~flash.core.data.process.Postprocess` to override with. + new_output_transform: :class:`~flash.core.data.io.output_transform.OutputTransform` to override with. new_output: :class:`~flash.core.data.io.output.Output` to override with. Returns: - The resolved :class:`~flash.core.data.process.Preprocess`, :class:`~flash.core.data.process.Postprocess`, - and :class:`~flash.core.data.io.output.Output`. 
+ The resolved :class:`~flash.core.data.process.Preprocess`, + :class:`~flash.core.data.io.output_transform.OutputTransform`, and + :class:`~flash.core.data.io.output.Output`. """ deserializer = old_deserializer if new_deserializer is not None and type(new_deserializer) != Deserializer: @@ -600,15 +604,15 @@ def _resolve( if new_preprocess is not None and type(new_preprocess) != Preprocess: preprocess = new_preprocess - postprocess = old_postprocess - if new_postprocess is not None and type(new_postprocess) != Postprocess: - postprocess = new_postprocess + output_transform = old_output_transform + if new_output_transform is not None and type(new_output_transform) != OutputTransform: + output_transform = new_output_transform output = old_output if new_output is not None and type(new_output) != Output: output = new_output - return deserializer, preprocess, postprocess, output + return deserializer, preprocess, output_transform, output @torch.jit.unused @property @@ -669,7 +673,7 @@ def build_data_pipeline( data_pipeline: Optional[DataPipeline] = None, ) -> Optional[DataPipeline]: """Build a :class:`.DataPipeline` incorporating available - :class:`~flash.core.data.process.Preprocess` and :class:`~flash.core.data.process.Postprocess` + :class:`~flash.core.data.process.Preprocess` and :class:`~flash.core.data.io.output_transform.OutputTransform` objects. These will be overridden in the following resolution order (lowest priority first): - Lightning ``Datamodule``, either attached to the :class:`.Trainer` or to the :class:`.Task`. @@ -682,12 +686,13 @@ def build_data_pipeline( the current data source format used. deserializer: deserializer to use data_pipeline: Optional highest priority source of - :class:`~flash.core.data.process.Preprocess` and :class:`~flash.core.data.process.Postprocess`. + :class:`~flash.core.data.process.Preprocess` and + :class:`~flash.core.data.io.output_transform.OutputTransform`. Returns: The fully resolved :class:`.DataPipeline`. 
""" - deserializer, old_data_source, preprocess, postprocess, output = None, None, None, None, None + deserializer, old_data_source, preprocess, output_transform, output = None, None, None, None, None # Datamodule datamodule = None @@ -699,32 +704,32 @@ def build_data_pipeline( if getattr(datamodule, "data_pipeline", None) is not None: old_data_source = getattr(datamodule.data_pipeline, "data_source", None) preprocess = getattr(datamodule.data_pipeline, "_preprocess_pipeline", None) - postprocess = getattr(datamodule.data_pipeline, "_postprocess_pipeline", None) + output_transform = getattr(datamodule.data_pipeline, "_output_transform", None) output = getattr(datamodule.data_pipeline, "_output", None) deserializer = getattr(datamodule.data_pipeline, "_deserializer", None) # Defaults / task attributes - deserializer, preprocess, postprocess, output = Task._resolve( + deserializer, preprocess, output_transform, output = Task._resolve( deserializer, preprocess, - postprocess, + output_transform, output, self._deserializer, self._preprocess, - self._postprocess, + self._output_transform, self._output, ) # Datapipeline if data_pipeline is not None: - deserializer, preprocess, postprocess, output = Task._resolve( + deserializer, preprocess, output_transform, output = Task._resolve( deserializer, preprocess, - postprocess, + output_transform, output, getattr(data_pipeline, "_deserializer", None), getattr(data_pipeline, "_preprocess_pipeline", None), - getattr(data_pipeline, "_postprocess_pipeline", None), + getattr(data_pipeline, "_output_transform", None), getattr(data_pipeline, "_output", None), ) @@ -739,7 +744,7 @@ def build_data_pipeline( if deserializer is None or type(deserializer) is Deserializer: deserializer = getattr(preprocess, "deserializer", deserializer) - data_pipeline = DataPipeline(data_source, preprocess, postprocess, deserializer, output) + data_pipeline = DataPipeline(data_source, preprocess, output_transform, deserializer, output) self._data_pipeline_state = self._data_pipeline_state or DataPipelineState() self.attach_data_pipeline_state(self._data_pipeline_state) self._data_pipeline_state = data_pipeline.initialize(self._data_pipeline_state) @@ -763,14 +768,14 @@ def data_pipeline(self) -> DataPipeline: @torch.jit.unused @data_pipeline.setter def data_pipeline(self, data_pipeline: Optional[DataPipeline]) -> None: - self._deserializer, self._preprocess, self._postprocess, self.output = Task._resolve( + self._deserializer, self._preprocess, self._output_transform, self.output = Task._resolve( self._deserializer, self._preprocess, - self._postprocess, + self._output_transform, self._output, getattr(data_pipeline, "_deserializer", None), getattr(data_pipeline, "_preprocess_pipeline", None), - getattr(data_pipeline, "_postprocess_pipeline", None), + getattr(data_pipeline, "_output_transform", None), getattr(data_pipeline, "_output", None), ) @@ -785,8 +790,8 @@ def preprocess(self) -> Preprocess: @torch.jit.unused @property - def postprocess(self) -> Postprocess: - return getattr(self.data_pipeline, "_postprocess_pipeline", None) + def output_transform(self) -> OutputTransform: + return getattr(self.data_pipeline, "_output_transform", None) def on_train_dataloader(self) -> None: if self.data_pipeline is not None: diff --git a/flash/core/serve/flash_components.py b/flash/core/serve/flash_components.py index 6a70edd41d..70f36879a5 100644 --- a/flash/core/serve/flash_components.py +++ b/flash/core/serve/flash_components.py @@ -57,7 +57,9 @@ def __init__(self, model): 
self.data_pipeline = model.build_data_pipeline() self.worker_preprocessor = self.data_pipeline.worker_preprocessor(RunningStage.PREDICTING, is_serving=True) self.device_preprocessor = self.data_pipeline.device_preprocessor(RunningStage.PREDICTING) - self.postprocessor = self.data_pipeline.postprocessor(RunningStage.PREDICTING, is_serving=True) + self.output_transform_processor = self.data_pipeline.output_transform_processor( + RunningStage.PREDICTING, is_serving=True + ) # todo (tchaton) Remove this hack self.extra_arguments = len(inspect.signature(self.model.transfer_batch_to_device).parameters) == 3 self.device = self.model.device @@ -75,7 +77,7 @@ def predict(self, inputs): inputs = self.model.transfer_batch_to_device(inputs, self.device) inputs = self.device_preprocessor(inputs) preds = self.model.predict_step(inputs, 0) - preds = self.postprocessor(preds) + preds = self.output_transform_processor(preds) return preds return FlashServeModelComponent(model) diff --git a/flash/core/utilities/types.py b/flash/core/utilities/types.py index ec968792d8..21f45be693 100644 --- a/flash/core/utilities/types.py +++ b/flash/core/utilities/types.py @@ -4,7 +4,8 @@ from torchmetrics import Metric from flash.core.data.io.output import Output -from flash.core.data.process import Deserializer, Postprocess, Preprocess +from flash.core.data.io.output_transform import OutputTransform +from flash.core.data.process import Deserializer, Preprocess MODEL_TYPE = Optional[nn.Module] LOSS_FN_TYPE = Optional[Union[Callable, Mapping, Sequence]] @@ -15,5 +16,5 @@ METRICS_TYPE = Union[Metric, Mapping, Sequence, None] DESERIALIZER_TYPE = Optional[Union[Deserializer, Mapping[str, Deserializer]]] PREPROCESS_TYPE = Optional[Preprocess] -POSTPROCESS_TYPE = Optional[Postprocess] +OUTPUT_TRANSFORM_TYPE = Optional[OutputTransform] OUTPUT_TYPE = Optional[Output] diff --git a/flash/image/face_detection/data.py b/flash/image/face_detection/data.py index f926c24538..c6802766a4 100644 --- a/flash/image/face_detection/data.py +++ b/flash/image/face_detection/data.py @@ -18,7 +18,8 @@ from torch.utils.data import Dataset from flash.core.data.data_source import DatasetDataSource, DefaultDataKeys, DefaultDataSources -from flash.core.data.process import Postprocess, Preprocess +from flash.core.data.io.output_transform import OutputTransform +from flash.core.data.process import Preprocess from flash.core.data.transforms import ApplyToKeys from flash.core.utilities.imports import _FASTFACE_AVAILABLE, _TORCHVISION_AVAILABLE from flash.image.data import ImagePathsDataSource @@ -146,7 +147,7 @@ def default_transforms(self) -> Dict[str, Callable]: } -class FaceDetectionPostProcess(Postprocess): +class FaceDetectionOutputTransform(OutputTransform): """Generates preds from model output.""" @staticmethod @@ -169,4 +170,4 @@ def per_batch_transform(batch: Any) -> Any: class FaceDetectionData(ObjectDetectionData): preprocess_cls = FaceDetectionPreprocess - postprocess_cls = FaceDetectionPostProcess + output_transform_cls = FaceDetectionOutputTransform diff --git a/flash/image/face_detection/model.py b/flash/image/face_detection/model.py index bfb2e52c18..9246aaec45 100644 --- a/flash/image/face_detection/model.py +++ b/flash/image/face_detection/model.py @@ -115,11 +115,11 @@ def get_model( model.register_buffer("mean", getattr(pl_model, "mean")) model.register_buffer("std", getattr(pl_model, "std")) - # copy pasting `_postprocess` function from `fastface.FaceDetector` to `torch.nn.Module` - # set postprocess function + # copy pasting 
`_postprocess` function from `fastface.FaceDetector` to `torch.nn.Module` as `_output_transform` + # set output_transform function # this is called from FaceDetector lightning module form fastface itself # https://github.com/borhanMorphy/fastface/blob/master/fastface/module.py#L200 - setattr(model, "_postprocess", getattr(pl_model, "_postprocess")) + setattr(model, "_output_transform", getattr(pl_model, "_postprocess")) return model diff --git a/flash/image/instance_segmentation/data.py b/flash/image/instance_segmentation/data.py index 5623974495..15d70d7c95 100644 --- a/flash/image/instance_segmentation/data.py +++ b/flash/image/instance_segmentation/data.py @@ -16,7 +16,8 @@ from flash.core.data.callback import BaseDataFetcher from flash.core.data.data_module import DataModule from flash.core.data.data_source import DefaultDataKeys, DefaultDataSources -from flash.core.data.process import Postprocess, Preprocess +from flash.core.data.io.output_transform import OutputTransform +from flash.core.data.process import Preprocess from flash.core.integrations.icevision.data import IceVisionParserDataSource, IceVisionPathsDataSource from flash.core.integrations.icevision.transforms import default_transforms from flash.core.utilities.imports import _ICEVISION_AVAILABLE @@ -70,7 +71,7 @@ def train_default_transforms(self) -> Optional[Dict[str, Callable]]: return default_transforms(self.image_size) -class InstanceSegmentationPostProcess(Postprocess): +class InstanceSegmentationOutputTransform(OutputTransform): @staticmethod def uncollate(batch: Any) -> Any: return batch[DefaultDataKeys.PREDS] @@ -79,7 +80,7 @@ def uncollate(batch: Any) -> Any: class InstanceSegmentationData(DataModule): preprocess_cls = InstanceSegmentationPreprocess - postprocess_cls = InstanceSegmentationPostProcess + output_transform_cls = InstanceSegmentationOutputTransform @classmethod def from_coco( diff --git a/flash/image/instance_segmentation/model.py b/flash/image/instance_segmentation/model.py index c947a8c4f6..bd86e655f2 100644 --- a/flash/image/instance_segmentation/model.py +++ b/flash/image/instance_segmentation/model.py @@ -21,7 +21,7 @@ from flash.core.registry import FlashRegistry from flash.core.utilities.types import LR_SCHEDULER_TYPE, OPTIMIZER_TYPE, OUTPUT_TYPE from flash.image.instance_segmentation.backbones import INSTANCE_SEGMENTATION_HEADS -from flash.image.instance_segmentation.data import InstanceSegmentationPostProcess, InstanceSegmentationPreprocess +from flash.image.instance_segmentation.data import InstanceSegmentationOutputTransform, InstanceSegmentationPreprocess class InstanceSegmentation(AdapterTask): @@ -99,5 +99,5 @@ def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: "If you'd like to change this, extend the InstanceSegmentation Task and override `on_load_checkpoint`."
) self.data_pipeline = DataPipeline( - preprocess=InstanceSegmentationPreprocess(), postprocess=InstanceSegmentationPostProcess() + preprocess=InstanceSegmentationPreprocess(), output_transform=InstanceSegmentationOutputTransform() ) diff --git a/flash/image/segmentation/model.py b/flash/image/segmentation/model.py index 17f96a0fed..d00ee0e8d5 100644 --- a/flash/image/segmentation/model.py +++ b/flash/image/segmentation/model.py @@ -20,7 +20,7 @@ from flash.core.classification import ClassificationTask from flash.core.data.data_source import DefaultDataKeys -from flash.core.data.process import Postprocess +from flash.core.data.io.output_transform import OutputTransform from flash.core.registry import FlashRegistry from flash.core.utilities.imports import _KORNIA_AVAILABLE from flash.core.utilities.isinstance import _isinstance @@ -29,8 +29,8 @@ LR_SCHEDULER_TYPE, METRICS_TYPE, OPTIMIZER_TYPE, + OUTPUT_TRANSFORM_TYPE, OUTPUT_TYPE, - POSTPROCESS_TYPE, PREPROCESS_TYPE, ) from flash.image.segmentation.backbones import SEMANTIC_SEGMENTATION_BACKBONES from flash.image.segmentation.heads import SEMANTIC_SEGMENTATION_HEADS @@ -40,7 +40,7 @@ import kornia as K -class SemanticSegmentationPostprocess(Postprocess): +class SemanticSegmentationOutputTransform(OutputTransform): def per_sample_transform(self, sample: Any) -> Any: resize = K.geometry.Resize(sample[DefaultDataKeys.METADATA]["size"][-2:], interpolation="bilinear") sample[DefaultDataKeys.PREDS] = resize(sample[DefaultDataKeys.PREDS]) @@ -69,10 +69,10 @@ class SemanticSegmentation(ClassificationTask): learning_rate: Learning rate to use for training. multi_label: Whether the targets are multi-label or not. output: The :class:`~flash.core.data.io.output.Output` to use when formatting prediction outputs. - postprocess: :class:`~flash.core.data.process.Postprocess` use for post processing samples. + output_transform: :class:`~flash.core.data.io.output_transform.OutputTransform` used for post-processing samples.
""" - postprocess_cls = SemanticSegmentationPostprocess + output_transform_cls = SemanticSegmentationOutputTransform backbones: FlashRegistry = SEMANTIC_SEGMENTATION_BACKBONES @@ -95,7 +95,7 @@ def __init__( learning_rate: float = 1e-3, multi_label: bool = False, output: OUTPUT_TYPE = None, - postprocess: POSTPROCESS_TYPE = None, + output_transform: OUTPUT_TRANSFORM_TYPE = None, ) -> None: if metrics is None: metrics = IoU(num_classes=num_classes) @@ -115,7 +115,7 @@ def __init__( metrics=metrics, learning_rate=learning_rate, output=output or SegmentationLabels(), - postprocess=postprocess or self.postprocess_cls(), + output_transform=output_transform or self.output_transform_cls(), ) self.save_hyperparameters() diff --git a/flash/tabular/data.py b/flash/tabular/data.py index b078344366..6c981019bd 100644 --- a/flash/tabular/data.py +++ b/flash/tabular/data.py @@ -22,7 +22,8 @@ from flash.core.data.callback import BaseDataFetcher from flash.core.data.data_module import DataModule from flash.core.data.data_source import DataSource, DefaultDataKeys, DefaultDataSources -from flash.core.data.process import Deserializer, Postprocess, Preprocess +from flash.core.data.io.output_transform import OutputTransform +from flash.core.data.process import Deserializer, Preprocess from flash.core.utilities.imports import _PANDAS_AVAILABLE from flash.tabular.classification.utils import ( _compute_normalization, @@ -234,7 +235,7 @@ def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool = True) -> "Pr return cls(**state_dict) -class TabularPostprocess(Postprocess): +class TabularOutputTransform(OutputTransform): def uncollate(self, batch: Any) -> Any: return batch @@ -243,7 +244,7 @@ class TabularData(DataModule): """Data module for tabular tasks.""" preprocess_cls = TabularPreprocess - postprocess_cls = TabularPostprocess + output_transform_cls = TabularOutputTransform is_regression: bool = False diff --git a/flash/text/classification/data.py b/flash/text/classification/data.py index 1564277199..d64e9026c0 100644 --- a/flash/text/classification/data.py +++ b/flash/text/classification/data.py @@ -24,7 +24,8 @@ from flash.core.data.callback import BaseDataFetcher from flash.core.data.data_module import DataModule from flash.core.data.data_source import DataSource, DefaultDataKeys, DefaultDataSources, LabelsState -from flash.core.data.process import Deserializer, Postprocess, Preprocess +from flash.core.data.io.output_transform import OutputTransform +from flash.core.data.process import Deserializer, Preprocess from flash.core.integrations.labelstudio.data_source import LabelStudioTextClassificationDataSource from flash.core.utilities.imports import _TEXT_AVAILABLE, requires @@ -312,7 +313,7 @@ def collate(self, samples: Any) -> Tensor: return default_data_collator(samples) -class TextClassificationPostprocess(Postprocess): +class TextClassificationOutputTransform(OutputTransform): def per_batch_transform(self, batch: Any) -> Any: if isinstance(batch, SequenceClassifierOutput): batch = batch.logits @@ -323,7 +324,7 @@ class TextClassificationData(DataModule): """Data Module for text classification tasks.""" preprocess_cls = TextClassificationPreprocess - postprocess_cls = TextClassificationPostprocess + output_transform_cls = TextClassificationOutputTransform @property def backbone(self) -> Optional[str]: diff --git a/flash/text/question_answering/data.py b/flash/text/question_answering/data.py index 568bbf7d4e..c0caf41efc 100644 --- a/flash/text/question_answering/data.py +++ 
b/flash/text/question_answering/data.py @@ -30,7 +30,8 @@ from flash.core.data.callback import BaseDataFetcher from flash.core.data.data_module import DataModule from flash.core.data.data_source import DataSource, DefaultDataKeys, DefaultDataSources -from flash.core.data.process import Postprocess, Preprocess +from flash.core.data.io.output_transform import OutputTransform +from flash.core.data.process import Preprocess from flash.core.data.properties import ProcessState from flash.core.utilities.imports import _TEXT_AVAILABLE, requires from flash.core.utilities.stages import RunningStage @@ -589,7 +590,7 @@ def collate(self, samples: Any) -> Tensor: return default_data_collator(samples) -class QuestionAnsweringPostprocess(Postprocess): +class QuestionAnsweringOutputTransform(OutputTransform): @requires("text") def __init__(self): super().__init__() @@ -638,7 +639,7 @@ class QuestionAnsweringData(DataModule): """Data module for QuestionAnswering task.""" preprocess_cls = QuestionAnsweringPreprocess - postprocess_cls = QuestionAnsweringPostprocess + output_transform_cls = QuestionAnsweringOutputTransform @classmethod def from_squad_v2( diff --git a/flash/text/seq2seq/core/data.py b/flash/text/seq2seq/core/data.py index 5b256ce004..d523af6692 100644 --- a/flash/text/seq2seq/core/data.py +++ b/flash/text/seq2seq/core/data.py @@ -21,7 +21,8 @@ import flash from flash.core.data.data_module import DataModule from flash.core.data.data_source import DataSource, DefaultDataSources -from flash.core.data.process import Postprocess, Preprocess +from flash.core.data.io.output_transform import OutputTransform +from flash.core.data.process import Preprocess from flash.core.data.properties import ProcessState from flash.core.utilities.imports import _TEXT_AVAILABLE, requires from flash.text.classification.data import TextDeserializer @@ -319,7 +320,7 @@ def collate(self, samples: Any) -> Tensor: return default_data_collator(samples) -class Seq2SeqPostprocess(Postprocess): +class Seq2SeqOutputTransform(OutputTransform): @requires("text") def __init__(self): super().__init__() @@ -360,4 +361,4 @@ class Seq2SeqData(DataModule): """Data module for Seq2Seq tasks.""" preprocess_cls = Seq2SeqPreprocess - postprocess_cls = Seq2SeqPostprocess + output_transform_cls = Seq2SeqOutputTransform diff --git a/flash/text/seq2seq/summarization/data.py b/flash/text/seq2seq/summarization/data.py index 3c585f5a66..cd99caa490 100644 --- a/flash/text/seq2seq/summarization/data.py +++ b/flash/text/seq2seq/summarization/data.py @@ -13,7 +13,7 @@ # limitations under the License. 
from typing import Callable, Dict, Optional, Union -from flash.text.seq2seq.core.data import Seq2SeqData, Seq2SeqPostprocess, Seq2SeqPreprocess +from flash.text.seq2seq.core.data import Seq2SeqData, Seq2SeqOutputTransform, Seq2SeqPreprocess class SummarizationPreprocess(Seq2SeqPreprocess): @@ -45,4 +45,4 @@ def __init__( class SummarizationData(Seq2SeqData): preprocess_cls = SummarizationPreprocess - postprocess_cls = Seq2SeqPostprocess + output_transform_cls = Seq2SeqOutputTransform diff --git a/flash/text/seq2seq/summarization/model.py b/flash/text/seq2seq/summarization/model.py index 99bf064ad7..261de7a03a 100644 --- a/flash/text/seq2seq/summarization/model.py +++ b/flash/text/seq2seq/summarization/model.py @@ -83,7 +83,7 @@ def task(self) -> str: def compute_metrics(self, generated_tokens: torch.Tensor, batch: Dict, prefix: str) -> None: tgt_lns = self.tokenize_labels(batch["labels"]) - result = self.rouge(self._postprocess.uncollate(generated_tokens), tgt_lns) + result = self.rouge(self._output_transform.uncollate(generated_tokens), tgt_lns) self.log_dict(result, on_step=False, on_epoch=True, prog_bar=True) @staticmethod diff --git a/flash/text/seq2seq/translation/data.py b/flash/text/seq2seq/translation/data.py index 0c86022a4d..89a712d492 100644 --- a/flash/text/seq2seq/translation/data.py +++ b/flash/text/seq2seq/translation/data.py @@ -13,7 +13,7 @@ # limitations under the License. from typing import Callable, Dict, Optional, Union -from flash.text.seq2seq.core.data import Seq2SeqData, Seq2SeqPostprocess, Seq2SeqPreprocess +from flash.text.seq2seq.core.data import Seq2SeqData, Seq2SeqOutputTransform, Seq2SeqPreprocess class TranslationPreprocess(Seq2SeqPreprocess): @@ -46,4 +46,4 @@ class TranslationData(Seq2SeqData): """Data module for Translation tasks.""" preprocess_cls = TranslationPreprocess - postprocess_cls = Seq2SeqPostprocess + output_transform_cls = Seq2SeqOutputTransform diff --git a/flash/text/seq2seq/translation/model.py b/flash/text/seq2seq/translation/model.py index f47d4f6b08..d57f7775df 100644 --- a/flash/text/seq2seq/translation/model.py +++ b/flash/text/seq2seq/translation/model.py @@ -83,7 +83,7 @@ def compute_metrics(self, generated_tokens, batch, prefix): tgt_lns = self.tokenize_labels(batch["labels"]) # wrap targets in list as score expects a list of potential references tgt_lns = [[reference] for reference in tgt_lns] - result = self.bleu(self._postprocess.uncollate(generated_tokens), tgt_lns) + result = self.bleu(self._output_transform.uncollate(generated_tokens), tgt_lns) self.log(f"{prefix}_bleu_score", result, on_step=False, on_epoch=True, prog_bar=True) @staticmethod diff --git a/flash_notebooks/image_classification.ipynb b/flash_notebooks/image_classification.ipynb index 342bf72c74..4c5766ea7d 100644 --- a/flash_notebooks/image_classification.ipynb +++ b/flash_notebooks/image_classification.ipynb @@ -267,7 +267,7 @@ "metadata": {}, "outputs": [], "source": [ - "model = ImageClassifier.load_from_checkpoint(\"https://flash-weights.s3.amazonaws.com/image_classification_model.pt\")" + "model = ImageClassifier.load_from_checkpoint(\"https://flash-weights.s3.amazonaws.com/0.6.0/image_classification_model.pt\")" ] }, { diff --git a/tests/audio/speech_recognition/test_model.py b/tests/audio/speech_recognition/test_model.py index 91419baec6..e801c02224 100644 --- a/tests/audio/speech_recognition/test_model.py +++ b/tests/audio/speech_recognition/test_model.py @@ -22,7 +22,7 @@ from flash import Trainer from flash.__main__ import main from flash.audio import 
SpeechRecognition -from flash.audio.speech_recognition.data import SpeechRecognitionPostprocess, SpeechRecognitionPreprocess +from flash.audio.speech_recognition.data import SpeechRecognitionOutputTransform, SpeechRecognitionPreprocess from flash.core.data.data_source import DefaultDataKeys from flash.core.utilities.imports import _AUDIO_AVAILABLE from tests.helpers.utils import _AUDIO_TESTING, _SERVE_TESTING @@ -79,9 +79,9 @@ def test_jit(tmpdir): @mock.patch("flash._IS_TESTING", True) def test_serve(): model = SpeechRecognition(backbone=TEST_BACKBONE) - # TODO: Currently only servable once a preprocess and postprocess have been attached + # TODO: Currently only servable once a preprocess and output_transform have been attached model._preprocess = SpeechRecognitionPreprocess() - model._postprocess = SpeechRecognitionPostprocess() + model._output_transform = SpeechRecognitionOutputTransform() model.eval() model.serve() diff --git a/tests/core/data/io/test_output_transform.py b/tests/core/data/io/test_output_transform.py new file mode 100644 index 0000000000..5fb2b17d69 --- /dev/null +++ b/tests/core/data/io/test_output_transform.py @@ -0,0 +1,33 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch + +from flash.core.data.batch import default_uncollate +from flash.core.data.io.output_transform import _OutputTransformProcessor + + +def test_output_transform_processor_str(): + output_transform_processor = _OutputTransformProcessor( + default_uncollate, + torch.relu, + torch.softmax, + None, + ) + assert str(output_transform_processor) == ( + "_OutputTransformProcessor:\n" + "\t(per_batch_transform): FuncModule(relu)\n" + "\t(uncollate_fn): FuncModule(default_uncollate)\n" + "\t(per_sample_transform): FuncModule(softmax)\n" + "\t(output): None" + ) diff --git a/tests/core/data/test_batch.py b/tests/core/data/test_batch.py index c14cf35fcc..958f31bf85 100644 --- a/tests/core/data/test_batch.py +++ b/tests/core/data/test_batch.py @@ -18,7 +18,7 @@ from torch.testing import assert_allclose from torch.utils.data._utils.collate import default_collate -from flash.core.data.batch import _Postprocessor, _Preprocessor, _Sequential, default_uncollate +from flash.core.data.batch import _Preprocessor, _Sequential, default_uncollate from flash.core.utilities.stages import RunningStage @@ -62,22 +62,6 @@ def test_preprocessor_str(): ) -def test_postprocessor_str(): - postprocessor = _Postprocessor( - default_uncollate, - torch.relu, - torch.softmax, - None, - ) - assert str(postprocessor) == ( - "_Postprocessor:\n" - "\t(per_batch_transform): FuncModule(relu)\n" - "\t(uncollate_fn): FuncModule(default_uncollate)\n" - "\t(per_sample_transform): FuncModule(softmax)\n" - "\t(output): None" - ) - - class TestDefaultUncollate: BATCH_SIZE = 3 diff --git a/tests/core/data/test_data_pipeline.py b/tests/core/data/test_data_pipeline.py index ce4df2e9fb..51c0279661 100644 --- a/tests/core/data/test_data_pipeline.py +++ b/tests/core/data/test_data_pipeline.py @@ -25,12 
+25,13 @@ from flash import Trainer from flash.core.data.auto_dataset import IterableAutoDataset -from flash.core.data.batch import _Postprocessor, _Preprocessor +from flash.core.data.batch import _Preprocessor from flash.core.data.data_module import DataModule from flash.core.data.data_pipeline import _StageOrchestrator, DataPipeline, DataPipelineState from flash.core.data.data_source import DataSource from flash.core.data.io.output import Output -from flash.core.data.process import DefaultPreprocess, Deserializer, Postprocess, Preprocess +from flash.core.data.io.output_transform import _OutputTransformProcessor, OutputTransform +from flash.core.data.process import DefaultPreprocess, Deserializer, Preprocess from flash.core.data.properties import ProcessState from flash.core.data.states import PerBatchTransformOnDevice, ToTensorTransform from flash.core.model import Task @@ -73,23 +74,23 @@ def test_data_pipeline_str(): data_pipeline = DataPipeline( data_source=cast(DataSource, "data_source"), preprocess=cast(Preprocess, "preprocess"), - postprocess=cast(Postprocess, "postprocess"), + output_transform=cast(OutputTransform, "output_transform"), output=cast(Output, "output"), deserializer=cast(Deserializer, "deserializer"), ) expected = "data_source=data_source, deserializer=deserializer, " - expected += "preprocess=preprocess, postprocess=postprocess, output=output" + expected += "preprocess=preprocess, output_transform=output_transform, output=output" assert str(data_pipeline) == (f"DataPipeline({expected})") @pytest.mark.parametrize("use_preprocess", [False, True]) -@pytest.mark.parametrize("use_postprocess", [False, True]) -def test_data_pipeline_init_and_assignement(use_preprocess, use_postprocess, tmpdir): +@pytest.mark.parametrize("use_output_transform", [False, True]) +def test_data_pipeline_init_and_assignement(use_preprocess, use_output_transform, tmpdir): class CustomModel(Task): - def __init__(self, postprocess: Optional[Postprocess] = None): + def __init__(self, output_transform: Optional[OutputTransform] = None): super().__init__(model=torch.nn.Linear(1, 1), loss_fn=torch.nn.MSELoss()) - self._postprocess = postprocess + self._output_transform = output_transform def train_dataloader(self) -> Any: return DataLoader(DummyDataset()) @@ -97,17 +98,17 @@ def train_dataloader(self) -> Any: class SubPreprocess(DefaultPreprocess): pass - class SubPostprocess(Postprocess): + class SubOutputTransform(OutputTransform): pass data_pipeline = DataPipeline( preprocess=SubPreprocess() if use_preprocess else None, - postprocess=SubPostprocess() if use_postprocess else None, + output_transform=SubOutputTransform() if use_output_transform else None, ) assert isinstance(data_pipeline._preprocess_pipeline, SubPreprocess if use_preprocess else DefaultPreprocess) - assert isinstance(data_pipeline._postprocess_pipeline, SubPostprocess if use_postprocess else Postprocess) + assert isinstance(data_pipeline._output_transform, SubOutputTransform if use_output_transform else OutputTransform) - model = CustomModel(postprocess=Postprocess()) + model = CustomModel(output_transform=OutputTransform()) model.data_pipeline = data_pipeline # TODO: the line below should make the same effect but it's not # data_pipeline._attach_to_model(model) @@ -117,10 +118,10 @@ class SubPostprocess(Postprocess): else: assert model._preprocess is None or isinstance(model._preprocess, Preprocess) - if use_postprocess: - assert isinstance(model._postprocess, SubPostprocess) + if use_output_transform: + assert 
isinstance(model._output_transform, SubOutputTransform) else: - assert model._postprocess is None or isinstance(model._postprocess, Postprocess) + assert model._output_transform is None or isinstance(model._output_transform, OutputTransform) def test_data_pipeline_is_overriden_and_resolve_function_hierarchy(tmpdir): @@ -294,9 +295,9 @@ def test_data_pipeline_predict_worker_preprocessor_and_device_preprocessor(): def test_detach_preprocessing_from_model(tmpdir): class CustomModel(Task): - def __init__(self, postprocess: Optional[Postprocess] = None): + def __init__(self, output_transform: Optional[OutputTransform] = None): super().__init__(model=torch.nn.Linear(1, 1), loss_fn=torch.nn.MSELoss()) - self._postprocess = postprocess + self._output_transform = output_transform def train_dataloader(self) -> Any: return DataLoader(DummyDataset()) @@ -355,7 +356,7 @@ class SubPreprocess(DefaultPreprocess): class CustomModel(Task): def __init__(self): super().__init__(model=torch.nn.Linear(1, 1), loss_fn=torch.nn.MSELoss()) - self._postprocess = Postprocess() + self._output_transform = OutputTransform() def training_step(self, batch: Any, batch_idx: int) -> Any: pass @@ -461,7 +462,7 @@ def on_predict_dataloader(self) -> None: assert isinstance(self.predict_step, _StageOrchestrator) self._assert_stage_orchestrator_state(self.transfer_batch_to_device._stage_mapping, current_running_stage) self._assert_stage_orchestrator_state( - self.predict_step._stage_mapping, current_running_stage, cls=_Postprocessor + self.predict_step._stage_mapping, current_running_stage, cls=_OutputTransformProcessor ) def on_fit_end(self) -> None: @@ -494,16 +495,20 @@ def test_stage_orchestrator_state_attach_detach(tmpdir): _original_predict_step = model.predict_step class CustomDataPipeline(DataPipeline): - def _attach_postprocess_to_model(self, model: "Task", _postprocesssor: _Postprocessor) -> "Task": - model.predict_step = self._model_predict_step_wrapper(model.predict_step, _postprocesssor, model) + def _attach_output_transform_to_model( + self, model: "Task", _output_transform_processor: _OutputTransformProcessor + ) -> "Task": + model.predict_step = self._model_predict_step_wrapper( + model.predict_step, _output_transform_processor, model + ) return model data_pipeline = CustomDataPipeline(preprocess=preprocess) - _postprocesssor = data_pipeline._create_uncollate_postprocessors(RunningStage.PREDICTING) - data_pipeline._attach_postprocess_to_model(model, _postprocesssor) + _output_transform_processor = data_pipeline._create_output_transform_processor(RunningStage.PREDICTING) + data_pipeline._attach_output_transform_to_model(model, _output_transform_processor) assert model.predict_step._original == _original_predict_step - assert model.predict_step._stage_mapping[RunningStage.PREDICTING] == _postprocesssor - data_pipeline._detach_postprocess_from_model(model) + assert model.predict_step._stage_mapping[RunningStage.PREDICTING] == _output_transform_processor + data_pipeline._detach_output_transform_from_model(model) assert model.predict_step == _original_predict_step diff --git a/tests/core/test_model.py b/tests/core/test_model.py index 9a465c3e10..d6b9f96999 100644 --- a/tests/core/test_model.py +++ b/tests/core/test_model.py @@ -35,7 +35,8 @@ from flash.audio import SpeechRecognition from flash.core.adapter import Adapter from flash.core.classification import ClassificationTask -from flash.core.data.process import DefaultPreprocess, Postprocess +from flash.core.data.io.output_transform import OutputTransform +from 
flash.core.data.process import DefaultPreprocess from flash.core.utilities.imports import _TORCH_OPTIMIZER_AVAILABLE, _TRANSFORMERS_AVAILABLE, Image from flash.image import ImageClassificationData, ImageClassifier, SemanticSegmentation from flash.tabular import TabularClassifier @@ -64,7 +65,7 @@ def __getitem__(self, index: int) -> Tensor: return torch.rand(1, 28, 28) -class DummyPostprocess(Postprocess): +class DummyOutputTransform(OutputTransform): pass @@ -222,10 +223,10 @@ def test_classification_task_trainer_predict(tmpdir): def test_task_datapipeline_save(tmpdir): model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10), nn.Softmax()) train_dl = torch.utils.data.DataLoader(DummyDataset()) - task = ClassificationTask(model, loss_fn=F.nll_loss, postprocess=DummyPostprocess()) + task = ClassificationTask(model, loss_fn=F.nll_loss, output_transform=DummyOutputTransform()) # to check later - task.postprocess.test = True + task.output_transform.test = True # generate a checkpoint trainer = pl.Trainer( @@ -242,7 +243,7 @@ def test_task_datapipeline_save(tmpdir): # load from file task = ClassificationTask.load_from_checkpoint(path, model=model) - assert task.postprocess.test + assert task.output_transform.test @pytest.mark.parametrize( diff --git a/tests/image/face_detection/test_model.py b/tests/image/face_detection/test_model.py index 05228c6586..826b9fe64f 100644 --- a/tests/image/face_detection/test_model.py +++ b/tests/image/face_detection/test_model.py @@ -46,7 +46,7 @@ def test_fastface_forward(): model = FaceDetector(model="lffd_slim") mock_batch = torch.randn(2, 3, 256, 256) - # test model forward (tests: _prepare_batch, logits_to_preds, _postprocess from ff) + # test model forward (tests: _prepare_batch, logits_to_preds, _output_transform from ff) model(mock_batch) diff --git a/tests/text/classification/test_model.py b/tests/text/classification/test_model.py index bf7609dd17..ac7c005105 100644 --- a/tests/text/classification/test_model.py +++ b/tests/text/classification/test_model.py @@ -23,7 +23,7 @@ from flash.core.data.data_source import DefaultDataKeys from flash.core.utilities.imports import _TEXT_AVAILABLE from flash.text import TextClassifier -from flash.text.classification.data import TextClassificationPostprocess, TextClassificationPreprocess +from flash.text.classification.data import TextClassificationOutputTransform, TextClassificationPreprocess from tests.helpers.utils import _SERVE_TESTING, _TEXT_TESTING # ======== Mock functions ======== @@ -77,9 +77,9 @@ def test_jit(tmpdir): @mock.patch("flash._IS_TESTING", True) def test_serve(): model = TextClassifier(2, TEST_BACKBONE) - # TODO: Currently only servable once a preprocess and postprocess have been attached + # TODO: Currently only servable once a preprocess and output_transform have been attached model._preprocess = TextClassificationPreprocess(backbone=TEST_BACKBONE) - model._postprocess = TextClassificationPostprocess() + model._output_transform = TextClassificationOutputTransform() model.eval() model.serve() diff --git a/tests/text/question_answering/test_data.py b/tests/text/question_answering/test_data.py index 6a6893babf..a57bf4f2dd 100644 --- a/tests/text/question_answering/test_data.py +++ b/tests/text/question_answering/test_data.py @@ -157,9 +157,9 @@ def test_from_files(tmpdir): @pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.") -def test_postprocess_tokenizer(tmpdir): - """Tests that the tokenizer property in ``QuestionAnsweringPostprocess`` resolves correctly when a 
different - backbone is used.""" +def test_output_transform_tokenizer(tmpdir): + """Tests that the tokenizer property in ``QuestionAnsweringOutputTransform`` resolves correctly when a + different backbone is used.""" backbone = "allenai/longformer-base-4096" json_path = json_data(tmpdir, TEST_JSON_DATA) dm = QuestionAnsweringData.from_json( @@ -172,8 +172,8 @@ def test_postprocess_tokenizer(tmpdir): ) pipeline = dm.data_pipeline pipeline.initialize() - assert pipeline._postprocess_pipeline.backbone == backbone - assert pipeline._postprocess_pipeline.tokenizer is not None + assert pipeline._output_transform.backbone == backbone + assert pipeline._output_transform.tokenizer is not None @pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") diff --git a/tests/text/seq2seq/core/test_data.py b/tests/text/seq2seq/core/test_data.py index d52bd9132a..8b9e2f862c 100644 --- a/tests/text/seq2seq/core/test_data.py +++ b/tests/text/seq2seq/core/test_data.py @@ -22,7 +22,7 @@ Seq2SeqDataSource, Seq2SeqFileDataSource, Seq2SeqJSONDataSource, - Seq2SeqPostprocess, + Seq2SeqOutputTransform, Seq2SeqSentencesDataSource, ) from tests.helpers.utils import _TEXT_TESTING @@ -41,7 +41,7 @@ (Seq2SeqCSVDataSource, {"backbone": "sshleifer/tiny-mbart"}), (Seq2SeqJSONDataSource, {"backbone": "sshleifer/tiny-mbart"}), (Seq2SeqSentencesDataSource, {"backbone": "sshleifer/tiny-mbart"}), - (Seq2SeqPostprocess, {}), + (Seq2SeqOutputTransform, {}), ], ) def test_tokenizer_state(cls, kwargs): diff --git a/tests/text/seq2seq/summarization/test_data.py b/tests/text/seq2seq/summarization/test_data.py index 43f75cf54a..e29e691791 100644 --- a/tests/text/seq2seq/summarization/test_data.py +++ b/tests/text/seq2seq/summarization/test_data.py @@ -95,8 +95,8 @@ def test_from_files(tmpdir): @pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.") -def test_postprocess_tokenizer(tmpdir): - """Tests that the tokenizer property in ``SummarizationPostprocess`` resolves correctly when a different +def test_output_transform_tokenizer(tmpdir): + """Tests that the tokenizer property in ``SummarizationOutputTransform`` resolves correctly when a different backbone is used.""" backbone = "sshleifer/bart-tiny-random" csv_path = csv_data(tmpdir) @@ -105,8 +105,8 @@ def test_postprocess_tokenizer(tmpdir): ) pipeline = dm.data_pipeline pipeline.initialize() - assert pipeline._postprocess_pipeline.backbone_state.backbone == backbone - assert pipeline._postprocess_pipeline.tokenizer is not None + assert pipeline._output_transform.backbone_state.backbone == backbone + assert pipeline._output_transform.tokenizer is not None @pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") diff --git a/tests/text/seq2seq/summarization/test_model.py b/tests/text/seq2seq/summarization/test_model.py index c6adf69fdc..b6e02e9895 100644 --- a/tests/text/seq2seq/summarization/test_model.py +++ b/tests/text/seq2seq/summarization/test_model.py @@ -21,7 +21,7 @@ from flash import Trainer from flash.core.utilities.imports import _TEXT_AVAILABLE from flash.text import SummarizationTask -from flash.text.seq2seq.core.data import Seq2SeqPostprocess +from flash.text.seq2seq.core.data import Seq2SeqOutputTransform from flash.text.seq2seq.summarization.data import SummarizationPreprocess from tests.helpers.utils import _SERVE_TESTING, _TEXT_TESTING @@ -78,9 +78,9 @@ def test_jit(tmpdir): @mock.patch("flash._IS_TESTING", True) def test_serve(): model = SummarizationTask(TEST_BACKBONE) - # TODO: Currently 
only servable once a preprocess and postprocess have been attached + # TODO: Currently only servable once a preprocess and output_transform have been attached model._preprocess = SummarizationPreprocess(backbone=TEST_BACKBONE) - model._postprocess = Seq2SeqPostprocess() + model._output_transform = Seq2SeqOutputTransform() model.eval() model.serve() diff --git a/tests/text/seq2seq/translation/test_model.py b/tests/text/seq2seq/translation/test_model.py index 237fa3bb5a..e552b74385 100644 --- a/tests/text/seq2seq/translation/test_model.py +++ b/tests/text/seq2seq/translation/test_model.py @@ -21,7 +21,7 @@ from flash import Trainer from flash.core.utilities.imports import _TEXT_AVAILABLE from flash.text import TranslationTask -from flash.text.seq2seq.core.data import Seq2SeqPostprocess +from flash.text.seq2seq.core.data import Seq2SeqOutputTransform from flash.text.seq2seq.translation.data import TranslationPreprocess from tests.helpers.utils import _SERVE_TESTING, _TEXT_TESTING @@ -78,9 +78,9 @@ def test_jit(tmpdir): @mock.patch("flash._IS_TESTING", True) def test_serve(): model = TranslationTask(TEST_BACKBONE) - # TODO: Currently only servable once a preprocess and postprocess have been attached + # TODO: Currently only servable once a preprocess and output_transform have been attached model._preprocess = TranslationPreprocess(backbone=TEST_BACKBONE) - model._postprocess = Seq2SeqPostprocess() + model._output_transform = Seq2SeqOutputTransform() model.eval() model.serve()
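
For downstream code, the rename shown in this patch is mechanical: subclasses of `Postprocess` become subclasses of `OutputTransform` (the hook names `per_batch_transform`, `per_sample_transform`, and `uncollate` are unchanged), the `postprocess=` keyword on `DataPipeline` and on tasks becomes `output_transform=`, and the `postprocess_cls` class attribute on data modules becomes `output_transform_cls`. Below is a minimal migration sketch using only imports that appear in this diff; `MyOutputTransform` and its trivial hook bodies are hypothetical placeholders for user code, not part of the patch.

from typing import Any

from flash.core.data.data_pipeline import DataPipeline
from flash.core.data.io.output_transform import OutputTransform
from flash.core.data.process import DefaultPreprocess


class MyOutputTransform(OutputTransform):
    """Previously ``class MyPostprocess(Postprocess)``; only the base class moves."""

    def per_batch_transform(self, batch: Any) -> Any:
        # unwrap or reshape the raw model output for the whole batch,
        # e.g. pulling logits out of a model-specific output object
        return batch

    def uncollate(self, batch: Any) -> Any:
        # split the batched predictions back into one entry per sample
        return list(batch)


# Previously: DataPipeline(preprocess=..., postprocess=MyPostprocess())
data_pipeline = DataPipeline(
    preprocess=DefaultPreprocess(),
    output_transform=MyOutputTransform(),
)

As the tests above illustrate, the same rename applies to the private attribute attached to a task for serving (`model._postprocess` becomes `model._output_transform`), so any code poking at that attribute directly must be updated as well.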