This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Merge branch 'master' into fix/load-rgba
ethanwharris authored Nov 8, 2021
2 parents f951fcb + 2260d53 commit 339bb67
Showing 49 changed files with 440 additions and 386 deletions.
4 changes: 3 additions & 1 deletion CHANGELOG.md
@@ -12,7 +12,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Changed classes named `*Serializer` and properties / variables named `serializer` to be `*Output` and `output` respectively ([#927](https://github.com/PyTorchLightning/lightning-flash/pull/927))

- Simplify loading RGBA images (drop alpha channel by default) ([#946](https://github.com/PyTorchLightning/lightning-flash/pull/946))
- Changed `Postprocess` to `OutputTransform` ([#942](https://github.com/PyTorchLightning/lightning-flash/pull/942))

- Changed loading of RGBA images to drop alpha channel by default ([#946](https://github.com/PyTorchLightning/lightning-flash/pull/946))

### Deprecated

2 changes: 1 addition & 1 deletion docs/source/api/audio.rst
@@ -33,7 +33,7 @@ __________________

speech_recognition.data.SpeechRecognitionPreprocess
speech_recognition.data.SpeechRecognitionBackboneState
speech_recognition.data.SpeechRecognitionPostprocess
speech_recognition.data.SpeechRecognitionOutputTransform
speech_recognition.data.SpeechRecognitionCSVDataSource
speech_recognition.data.SpeechRecognitionJSONDataSource
speech_recognition.data.BaseSpeechRecognition
2 changes: 1 addition & 1 deletion docs/source/api/data.rst
@@ -122,7 +122,7 @@ _______________________
~flash.core.data.process.DefaultPreprocess
~flash.core.data.process.DeserializerMapping
~flash.core.data.process.Deserializer
~flash.core.data.process.Postprocess
~flash.core.data.io.output_transform.OutputTransform
~flash.core.data.process.Preprocess

flash.core.data.properties
2 changes: 1 addition & 1 deletion docs/source/api/flash.rst
@@ -11,7 +11,7 @@ flash
~flash.core.data.data_module.DataModule
~flash.core.data.callback.FlashCallback
~flash.core.data.process.Preprocess
~flash.core.data.process.Postprocess
~flash.core.data.io.output_transform.OutputTransform
~flash.core.data.io.output.Output
~flash.core.model.Task
~flash.core.trainer.Trainer
2 changes: 1 addition & 1 deletion docs/source/api/image.rst
@@ -101,7 +101,7 @@ ____________
segmentation.data.SemanticSegmentationPathsDataSource
segmentation.data.SemanticSegmentationFiftyOneDataSource
segmentation.data.SemanticSegmentationDeserializer
segmentation.model.SemanticSegmentationPostprocess
segmentation.model.SemanticSegmentationOutputTransform
segmentation.output.FiftyOneSegmentationLabels
segmentation.output.SegmentationLabels

2 changes: 1 addition & 1 deletion docs/source/api/tabular.rst
@@ -59,4 +59,4 @@ __________________
~data.TabularCSVDataSource
~data.TabularDeserializer
~data.TabularPreprocess
~data.TabularPostprocess
~data.TabularOutputTransform
6 changes: 3 additions & 3 deletions docs/source/api/text.rst
@@ -20,7 +20,7 @@ ______________
~classification.model.TextClassifier
~classification.data.TextClassificationData

classification.data.TextClassificationPostprocess
classification.data.TextClassificationOutputTransform
classification.data.TextClassificationPreprocess
classification.data.TextDeserializer
classification.data.TextDataSource
@@ -48,7 +48,7 @@ __________________
question_answering.data.QuestionAnsweringDictionaryDataSource
question_answering.data.QuestionAnsweringFileDataSource
question_answering.data.QuestionAnsweringJSONDataSource
question_answering.data.QuestionAnsweringPostprocess
question_answering.data.QuestionAnsweringOutputTransform
question_answering.data.QuestionAnsweringPreprocess
question_answering.data.SQuADDataSource

@@ -96,7 +96,7 @@ _______________
seq2seq.core.data.Seq2SeqDataSource
seq2seq.core.data.Seq2SeqFileDataSource
seq2seq.core.data.Seq2SeqJSONDataSource
seq2seq.core.data.Seq2SeqPostprocess
seq2seq.core.data.Seq2SeqOutputTransform
seq2seq.core.data.Seq2SeqPreprocess
seq2seq.core.data.Seq2SeqSentencesDataSource
seq2seq.core.metrics.BLEUScore
28 changes: 14 additions & 14 deletions docs/source/general/data.rst
@@ -26,19 +26,19 @@ Here are common terms you need to be familiar with:
* - :class:`~flash.core.data.data_module.DataModule`
- The :class:`~flash.core.data.data_module.DataModule` contains the datasets, transforms and dataloaders.
* - :class:`~flash.core.data.data_pipeline.DataPipeline`
- The :class:`~flash.core.data.data_pipeline.DataPipeline` is Flash internal object to manage :class:`~flash.core.data.Deserializer`, :class:`~flash.core.data.data_source.DataSource`, :class:`~flash.core.data.process.Preprocess`, :class:`~flash.core.data.process.Postprocess`, and :class:`~flash.core.data.io.output.Output` objects.
- The :class:`~flash.core.data.data_pipeline.DataPipeline` is Flash internal object to manage :class:`~flash.core.data.Deserializer`, :class:`~flash.core.data.data_source.DataSource`, :class:`~flash.core.data.process.Preprocess`, :class:`~flash.core.data.io.output_transform.OutputTransform`, and :class:`~flash.core.data.io.output.Output` objects.
* - :class:`~flash.core.data.data_source.DataSource`
- The :class:`~flash.core.data.data_source.DataSource` provides :meth:`~flash.core.data.data_source.DataSource.load_data` and :meth:`~flash.core.data.data_source.DataSource.load_sample` hooks for creating data sets from metadata (such as folder names).
* - :class:`~flash.core.data.process.Preprocess`
- The :class:`~flash.core.data.process.Preprocess` provides a simple hook-based API to encapsulate your pre-processing logic.
These hooks (such as :meth:`~flash.core.data.process.Preprocess.pre_tensor_transform`) enable transformations to be applied to your data at every point along the pipeline (including on the device).
The :class:`~flash.core.data.data_pipeline.DataPipeline` contains a system to call the right hooks when needed.
The :class:`~flash.core.data.process.Preprocess` hooks can be either overridden directly or provided as a dictionary of transforms (mapping hook name to callable transform).
* - :class:`~flash.core.data.process.Postprocess`
- The :class:`~flash.core.data.process.Postprocess` provides a simple hook-based API to encapsulate your post-processing logic.
The :class:`~flash.core.data.process.Postprocess` hooks cover from model outputs to predictions export.
* - :class:`~flash.core.data.io.output_transform.OutputTransform`
- The :class:`~flash.core.data.io.output_transform.OutputTransform` provides a simple hook-based API to encapsulate your post-processing logic.
The :class:`~flash.core.data.io.output_transform.OutputTransform` hooks cover from model outputs to predictions export.
* - :class:`~flash.core.data.io.output.Output`
- The :class:`~flash.core.data.io.output.Output` provides a single :meth:`~flash.core.data.io.output.Output.serialize` method that is used to convert model outputs (after the :class:`~flash.core.data.process.Postprocess`) to the desired output format during prediction.
- The :class:`~flash.core.data.io.output.Output` provides a single :meth:`~flash.core.data.io.output.Output.serialize` method that is used to convert model outputs (after the :class:`~flash.core.data.io.output_transform.OutputTransform`) to the desired output format during prediction.


*******************************************
@@ -58,8 +58,8 @@ However, after model training, it requires a lot of engineering overhead to make
Usually, extra processing logic should be added to bridge the gap between training data and raw data.

The :class:`~flash.core.data.data_source.DataSource` class can be used to generate data sets from multiple sources (e.g. folders, numpy, etc.), that can then all be transformed in the same way.
The :class:`~flash.core.data.process.Preprocess` and :class:`~flash.core.data.process.Postprocess` classes can be used to manage the preprocessing and postprocessing transforms.
The :class:`~flash.core.data.io.output.Output` class provides the logic for converting :class:`~flash.core.data.process.Postprocess` outputs to the desired predict format (e.g. classes, labels, probabilities, etc.).
The :class:`~flash.core.data.process.Preprocess` and :class:`~flash.core.data.io.output_transform.OutputTransform` classes can be used to manage the preprocessing and postprocessing transforms.
The :class:`~flash.core.data.io.output.Output` class provides the logic for converting :class:`~flash.core.data.io.output_transform.OutputTransform` outputs to the desired predict format (e.g. classes, labels, probabilities, etc.).

By providing a series of hooks that can be overridden with custom data processing logic (or just targeted with transforms),
Flash gives the user much more granular control over their data processing flow.
@@ -72,7 +72,7 @@ Here are the primary advantages:


To change the processing behavior only on specific stages for a given hook,
you can prefix each of the :class:`~flash.core.data.process.Preprocess` and :class:`~flash.core.data.process.Postprocess`
you can prefix each of the :class:`~flash.core.data.process.Preprocess` and :class:`~flash.core.data.io.output_transform.OutputTransform`
hooks by adding ``train``, ``val``, ``test`` or ``predict``.

Check out :class:`~flash.core.data.process.Preprocess` for some examples.
@@ -383,17 +383,17 @@ Example::
predictions = lightning_module(data)


Postprocess and Output
OutputTransform and Output
__________________________


Once the predictions have been generated by the Flash :class:`~flash.core.model.Task`, the Flash
:class:`~flash.core.data.data_pipeline.DataPipeline` will execute the :class:`~flash.core.data.process.Postprocess` hooks and the
:class:`~flash.core.data.data_pipeline.DataPipeline` will execute the :class:`~flash.core.data.io.output_transform.OutputTransform` hooks and the
:class:`~flash.core.data.io.output.Output` behind the scenes.

First, the :meth:`~flash.core.data.process.Postprocess.per_batch_transform` hooks will be applied on the batch predictions.
Then, the :meth:`~flash.core.data.process.Postprocess.uncollate` will split the batch into individual predictions.
Next, the :meth:`~flash.core.data.process.Postprocess.per_sample_transform` will be applied on each prediction.
First, the :meth:`~flash.core.data.io.output_transform.OutputTransform.per_batch_transform` hooks will be applied on the batch predictions.
Then, the :meth:`~flash.core.data.io.output_transform.OutputTransform.uncollate` will split the batch into individual predictions.
Next, the :meth:`~flash.core.data.io.output_transform.OutputTransform.per_sample_transform` will be applied on each prediction.
Finally, the :meth:`~flash.core.data.io.output.Output.serialize` method will be called to serialize the predictions.

.. note:: The transform can be applied either on device or ``CPU``.
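
A minimal sketch of the order described above, written with plain Python callables rather than the Flash classes. The function ``run_output_pipeline``, the toy hooks, and the ``LABELS`` list are hypothetical names used only for illustration; they are not part of the Flash API.

```python
from typing import Any, Callable, List

LABELS = ["cat", "dog"]  # hypothetical label names


def run_output_pipeline(
    batch: Any,
    per_batch_transform: Callable[[Any], Any],
    uncollate: Callable[[Any], List[Any]],
    per_sample_transform: Callable[[Any], Any],
    serialize: Callable[[Any], Any],
) -> List[Any]:
    """Apply the four steps in the order the text lists them."""
    batch = per_batch_transform(batch)  # 1. transform the whole batch
    samples = uncollate(batch)  # 2. split the batch into samples
    samples = [per_sample_transform(s) for s in samples]  # 3. per-sample hook
    return [serialize(s) for s in samples]  # 4. convert to the output format


def double_scores(batch: List[List[float]]) -> List[List[float]]:
    # Toy batch-level hook: scale every score by 2.
    return [[2 * score for score in row] for row in batch]


def argmax(sample: List[float]) -> int:
    # Toy sample-level hook: index of the largest score.
    return max(range(len(sample)), key=sample.__getitem__)


predictions = run_output_pipeline(
    [[0.1, 0.9], [0.8, 0.2]],
    per_batch_transform=double_scores,
    uncollate=list,
    per_sample_transform=argmax,
    serialize=lambda index: LABELS[index],
)
print(predictions)
```

Keeping serialization as the last step means the same transformed predictions can be exported in different formats by swapping only the final callable.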
@@ -402,7 +402,7 @@ Here is the pseudo-code:

Example::

# This will be wrapped into a :class:`~flash.core.data.batch._Postprocessor`
# This will be wrapped into a :class:`~flash.core.data.batch._OutputTransformProcessor`
def uncollate_fn(batch: Any) -> Any:

batch = per_batch_transform(batch)
20 changes: 10 additions & 10 deletions docs/source/template/data.rst
@@ -11,7 +11,7 @@ Inside `data.py <https://github.com/PyTorchLightning/lightning-flash/blob/master
#. a :class:`~flash.core.data.process.Preprocess`
#. a :class:`~flash.core.data.data_module.DataModule`
#. a :class:`~flash.core.data.base_viz.BaseVisualization` *(optional)*
#. a :class:`~flash.core.data.process.Postprocess` *(optional)*
#. a :class:`~flash.core.data.io.output_transform.OutputTransform` *(optional)*

DataSource
^^^^^^^^^^
@@ -196,19 +196,19 @@ We can configure our custom visualization in the ``TemplateData`` using :meth:`~
:dedent: 4
:pyobject: TemplateData.configure_data_fetcher

Postprocess
^^^^^^^^^^^
OutputTransform
^^^^^^^^^^^^^^^

:class:`~flash.core.data.process.Postprocess` contains any transforms that need to be applied *after* the model.
:class:`~flash.core.data.io.output_transform.OutputTransform` contains any transforms that need to be applied *after* the model.
You may want to use it for: converting tokens back into text, applying an inverse normalization to an output image, resizing a generated image back to the size of the input, etc.
As an example, here's the :class:`~text.classification.data.TextClassificationPostprocess` which gets the logits from a ``SequenceClassifierOutput``:
As an example, here's the :class:`~text.classification.data.TextClassificationOutputTransform` which gets the logits from a ``SequenceClassifierOutput``:

.. literalinclude:: ../../../flash/text/classification/data.py
:language: python
:pyobject: TextClassificationPostprocess
:pyobject: TextClassificationOutputTransform

In your :class:`~flash.core.data.data_source.DataSource` or :class:`~flash.core.data.process.Preprocess`, you can add metadata to the batch using the :attr:`~flash.core.data.data_source.DefaultDataKeys.METADATA` key.
Your :class:`~flash.core.data.process.Postprocess` can then use this metadata in its transforms.
Your :class:`~flash.core.data.io.output_transform.OutputTransform` can then use this metadata in its transforms.
You should use this approach if your postprocessing depends on the state of the input before the :class:`~flash.core.data.process.Preprocess` transforms.
For example, if you want to resize the predictions to the original size of the inputs you should add the original image size in the :attr:`~flash.core.data.data_source.DefaultDataKeys.METADATA`.
Here's an example from the :class:`~flash.image.segmentation.SemanticSegmentationNumpyDataSource`:
@@ -218,13 +218,13 @@ Here's an example from the :class:`~flash.image.segmentation.SemanticSegmentatio
:dedent: 4
:pyobject: SemanticSegmentationNumpyDataSource.load_sample

The :attr:`~flash.core.data.data_source.DefaultDataKeys.METADATA` can now be referenced in your :class:`~flash.core.data.process.Postprocess`.
For example, here's the code for the ``per_sample_transform`` method of the :class:`~flash.image.segmentation.model.SemanticSegmentationPostprocess`:
The :attr:`~flash.core.data.data_source.DefaultDataKeys.METADATA` can now be referenced in your :class:`~flash.core.data.io.output_transform.OutputTransform`.
For example, here's the code for the ``per_sample_transform`` method of the :class:`~flash.image.segmentation.model.SemanticSegmentationOutputTransform`:

.. literalinclude:: ../../../flash/image/segmentation/model.py
:language: python
:dedent: 4
:pyobject: SemanticSegmentationPostprocess.per_sample_transform
:pyobject: SemanticSegmentationOutputTransform.per_sample_transform
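
As a rough illustration of the idea, and not the Flash implementation, the sketch below restores a predicted mask to the original input size recorded in the metadata. The keys ``PREDS`` and ``METADATA``, the ``"size"`` entry, and the nearest-neighbour helper are assumptions made for this example; plain nested lists stand in for tensors.

```python
from typing import Any, Dict, List

PREDS = "preds"  # hypothetical key holding the predicted mask
METADATA = "metadata"  # stand-in for DefaultDataKeys.METADATA


def resize_nearest(mask: List[List[int]], height: int, width: int) -> List[List[int]]:
    """Nearest-neighbour resize of a 2D mask stored as nested lists."""
    src_h, src_w = len(mask), len(mask[0])
    return [
        [mask[row * src_h // height][col * src_w // width] for col in range(width)]
        for row in range(height)
    ]


def per_sample_transform(sample: Dict[str, Any]) -> List[List[int]]:
    # The original size was stored in the metadata when the sample was loaded,
    # before any input transform changed the image shape.
    height, width = sample[METADATA]["size"]
    return resize_nearest(sample[PREDS], height, width)


sample = {PREDS: [[0, 1], [1, 0]], METADATA: {"size": (4, 4)}}
restored = per_sample_transform(sample)
print(restored)
```

The point of the design is that the transform needs no extra arguments: everything it depends on travels with the sample.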

------

2 changes: 1 addition & 1 deletion docs/source/template/optional.rst
@@ -23,7 +23,7 @@ Here's how we create our transforms in the :class:`~flash.image.classification.d
Add outputs to your Task
========================

We recommend that you do most of the heavy lifting in the :class:`~flash.core.data.process.Postprocess`.
We recommend that you do most of the heavy lifting in the :class:`~flash.core.data.io.output_transform.OutputTransform`.
Specifically, it should include any formatting and transforms that should always be applied to the predictions.
If you want to support different use cases that require different prediction formats, you should add some :class:`~flash.core.data.io.output.Output` implementations in an ``output.py`` file.

5 changes: 3 additions & 2 deletions flash/__init__.py
@@ -25,7 +25,8 @@
from flash.core.data.datasets import FlashDataset, FlashIterableDataset
from flash.core.data.input_transform import InputTransform
from flash.core.data.io.output import Output
from flash.core.data.process import Postprocess, Preprocess, Serializer
from flash.core.data.io.output_transform import OutputTransform
from flash.core.data.process import Preprocess, Serializer
from flash.core.model import Task # noqa: E402
from flash.core.trainer import Trainer # noqa: E402

@@ -47,7 +48,7 @@
"FlashIterableDataset",
"InputTransform",
"Output",
"Postprocess",
"OutputTransform",
"Preprocess",
"Serializer",
"Task",
9 changes: 5 additions & 4 deletions flash/audio/speech_recognition/data.py
@@ -30,7 +30,8 @@
DefaultDataSources,
PathsDataSource,
)
from flash.core.data.process import Deserializer, Postprocess, Preprocess
from flash.core.data.io.output_transform import OutputTransform
from flash.core.data.process import Deserializer, Preprocess
from flash.core.data.properties import ProcessState
from flash.core.utilities.imports import _AUDIO_AVAILABLE, requires

@@ -190,13 +191,13 @@ def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool = False):
@dataclass(unsafe_hash=True, frozen=True)
class SpeechRecognitionBackboneState(ProcessState):
"""The ``SpeechRecognitionBackboneState`` stores the backbone in use by the
:class:`~flash.audio.speech_recognition.data.SpeechRecognitionPostprocess`
:class:`~flash.audio.speech_recognition.data.SpeechRecognitionOutputTransform`
"""

backbone: str


class SpeechRecognitionPostprocess(Postprocess):
class SpeechRecognitionOutputTransform(OutputTransform):
@requires("audio")
def __init__(self):
super().__init__()
@@ -237,4 +238,4 @@ class SpeechRecognitionData(DataModule):
"""Data Module for text classification tasks."""

preprocess_cls = SpeechRecognitionPreprocess
postprocess_cls = SpeechRecognitionPostprocess
output_transform_cls = SpeechRecognitionOutputTransform
74 changes: 0 additions & 74 deletions flash/core/data/batch.py
@@ -244,80 +244,6 @@ def __str__(self) -> str:
)


class _Postprocessor(torch.nn.Module):
"""This class is used to encapsultate the following functions of a Postprocess Object:
Inside main process:
per_batch_transform: Function to transform a batch
per_sample_transform: Function to transform an individual sample
uncollate_fn: Function to split a batch into samples
per_sample_transform: Function to transform an individual sample
save_fn: Function to save all data
save_per_sample: Function to save an individual sample
is_serving: Whether the Postprocessor is used in serving mode.
"""

def __init__(
self,
uncollate_fn: Callable,
per_batch_transform: Callable,
per_sample_transform: Callable,
output: Optional[Callable],
save_fn: Optional[Callable] = None,
save_per_sample: bool = False,
is_serving: bool = False,
):
super().__init__()
self.uncollate_fn = convert_to_modules(uncollate_fn)
self.per_batch_transform = convert_to_modules(per_batch_transform)
self.per_sample_transform = convert_to_modules(per_sample_transform)
self.output = convert_to_modules(output)
self.save_fn = convert_to_modules(save_fn)
self.save_per_sample = convert_to_modules(save_per_sample)
self.is_serving = is_serving

@staticmethod
def _extract_metadata(batch: Any) -> Tuple[Any, Optional[Any]]:
metadata = None
if isinstance(batch, Mapping) and DefaultDataKeys.METADATA in batch:
metadata = batch.pop(DefaultDataKeys.METADATA, None)
return batch, metadata

def forward(self, batch: Sequence[Any]):
batch, metadata = self._extract_metadata(batch)
uncollated = self.uncollate_fn(self.per_batch_transform(batch))
if metadata:
for sample, sample_metadata in zip(uncollated, metadata):
sample[DefaultDataKeys.METADATA] = sample_metadata

final_preds = [self.per_sample_transform(sample) for sample in uncollated]

if self.output is not None:
final_preds = [self.output(sample) for sample in final_preds]

if isinstance(uncollated, Tensor) and isinstance(final_preds[0], Tensor):
final_preds = torch.stack(final_preds)
else:
final_preds = type(final_preds)(final_preds)

if self.save_fn:
if self.save_per_sample:
for pred in final_preds:
self.save_fn(pred)
else:
self.save_fn(final_preds)
return final_preds

def __str__(self) -> str:
return (
"_Postprocessor:\n"
f"\t(per_batch_transform): {str(self.per_batch_transform)}\n"
f"\t(uncollate_fn): {str(self.uncollate_fn)}\n"
f"\t(per_sample_transform): {str(self.per_sample_transform)}\n"
f"\t(output): {str(self.output)}"
)


def default_uncollate(batch: Any):
"""
This function is used to uncollate a batch into samples.