Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Rename Serializer to Output and move to flash.core.data.io.output #927

Merged
merged 22 commits into from
Nov 5, 2021
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/source/api/core.rst
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,12 @@ _________________________
:template: classtemplate.rst

~flash.core.classification.Classes
~flash.core.classification.ClassificationSerializer
~flash.core.classification.ClassificationOutput
~flash.core.classification.ClassificationTask
~flash.core.classification.FiftyOneLabels
~flash.core.classification.Labels
~flash.core.classification.Logits
~flash.core.classification.PredsClassificationSerializer
~flash.core.classification.PredsClassificationOutput
~flash.core.classification.Probabilities

flash.core.finetuning
Expand Down
13 changes: 11 additions & 2 deletions docs/source/api/data.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,17 @@ flash.core.data
:local:
:backlinks: top

flash.core.data.io.output
_________________________

.. autosummary::
:toctree: generated/
:nosignatures:
:template: classtemplate.rst

~flash.core.data.io.output.Output
~flash.core.data.io.output.OutputMapping

flash.core.data.auto_dataset
____________________________

Expand Down Expand Up @@ -114,8 +125,6 @@ _______________________
~flash.core.data.process.Deserializer
~flash.core.data.process.Postprocess
~flash.core.data.process.Preprocess
~flash.core.data.process.SerializerMapping
~flash.core.data.process.Serializer

flash.core.data.properties
__________________________
Expand Down
2 changes: 1 addition & 1 deletion docs/source/api/flash.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,6 @@ flash
~flash.core.data.callback.FlashCallback
~flash.core.data.process.Preprocess
~flash.core.data.process.Postprocess
~flash.core.data.process.Serializer
~flash.core.data.io.output.Output
~flash.core.model.Task
~flash.core.trainer.Trainer
14 changes: 7 additions & 7 deletions docs/source/general/data.rst
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ Here are common terms you need to be familiar with:
* - :class:`~flash.core.data.data_module.DataModule`
- The :class:`~flash.core.data.data_module.DataModule` contains the datasets, transforms and dataloaders.
* - :class:`~flash.core.data.data_pipeline.DataPipeline`
- The :class:`~flash.core.data.data_pipeline.DataPipeline` is Flash internal object to manage :class:`~flash.core.data.Deserializer`, :class:`~flash.core.data.data_source.DataSource`, :class:`~flash.core.data.process.Preprocess`, :class:`~flash.core.data.process.Postprocess`, and :class:`~flash.core.data.process.Serializer` objects.
- The :class:`~flash.core.data.data_pipeline.DataPipeline` is Flash internal object to manage :class:`~flash.core.data.Deserializer`, :class:`~flash.core.data.data_source.DataSource`, :class:`~flash.core.data.process.Preprocess`, :class:`~flash.core.data.process.Postprocess`, and :class:`~flash.core.data.io.output.Output` objects.
* - :class:`~flash.core.data.data_source.DataSource`
- The :class:`~flash.core.data.data_source.DataSource` provides :meth:`~flash.core.data.data_source.DataSource.load_data` and :meth:`~flash.core.data.data_source.DataSource.load_sample` hooks for creating data sets from metadata (such as folder names).
* - :class:`~flash.core.data.process.Preprocess`
Expand All @@ -37,8 +37,8 @@ Here are common terms you need to be familiar with:
* - :class:`~flash.core.data.process.Postprocess`
- The :class:`~flash.core.data.process.Postprocess` provides a simple hook-based API to encapsulate your post-processing logic.
The :class:`~flash.core.data.process.Postprocess` hooks cover from model outputs to predictions export.
* - :class:`~flash.core.data.process.Serializer`
- The :class:`~flash.core.data.process.Serializer` provides a single :meth:`~flash.core.data.process.Serializer.serialize` method that is used to convert model outputs (after the :class:`~flash.core.data.process.Postprocess`) to the desired output format during prediction.
* - :class:`~flash.core.data.io.output.Output`
- The :class:`~flash.core.data.io.output.Output` provides a single :meth:`~flash.core.data.io.output.Output.serialize` method that is used to convert model outputs (after the :class:`~flash.core.data.process.Postprocess`) to the desired output format during prediction.


*******************************************
Expand All @@ -59,7 +59,7 @@ Usually, extra processing logic should be added to bridge the gap between traini

The :class:`~flash.core.data.data_source.DataSource` class can be used to generate data sets from multiple sources (e.g. folders, numpy, etc.), that can then all be transformed in the same way.
The :class:`~flash.core.data.process.Preprocess` and :class:`~flash.core.data.process.Postprocess` classes can be used to manage the preprocessing and postprocessing transforms.
The :class:`~flash.core.data.process.Serializer` class provides the logic for converting :class:`~flash.core.data.process.Postprocess` outputs to the desired predict format (e.g. classes, labels, probabilities, etc.).
The :class:`~flash.core.data.io.output.Output` class provides the logic for converting :class:`~flash.core.data.process.Postprocess` outputs to the desired predict format (e.g. classes, labels, probabilities, etc.).

By providing a series of hooks that can be overridden with custom data processing logic (or just targeted with transforms),
Flash gives the user much more granular control over their data processing flow.
Expand Down Expand Up @@ -383,18 +383,18 @@ Example::
predictions = lightning_module(data)


Postprocess and Serializer
Postprocess and Output
__________________________


Once the predictions have been generated by the Flash :class:`~flash.core.model.Task`, the Flash
:class:`~flash.core.data.data_pipeline.DataPipeline` will execute the :class:`~flash.core.data.process.Postprocess` hooks and the
:class:`~flash.core.data.process.Serializer` behind the scenes.
:class:`~flash.core.data.io.output.Output` behind the scenes.

First, the :meth:`~flash.core.data.process.Postprocess.per_batch_transform` hooks will be applied on the batch predictions.
Then, the :meth:`~flash.core.data.process.Postprocess.uncollate` will split the batch into individual predictions.
Next, the :meth:`~flash.core.data.process.Postprocess.per_sample_transform` will be applied on each prediction.
Finally, the :meth:`~flash.core.data.process.Serializer.serialize` method will be called to serialize the predictions.
Finally, the :meth:`~flash.core.data.io.output.Output.serialize` method will be called to serialize the predictions.

.. note:: The transform can be applied either on device or ``CPU``.

Expand Down
4 changes: 2 additions & 2 deletions docs/source/general/predictions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ Predict on a csv file
Serializing predictions
=======================

To change how predictions are serialized you can attach a :class:`~flash.core.data.process.Serializer` to your
To change how predictions are serialized you can attach a :class:`~flash.core.data.io.output.Output` to your
:class:`~flash.core.model.Task`. For example, you can choose to serialize outputs as probabilities (for more options see the API
reference below).

Expand All @@ -71,7 +71,7 @@ reference below).
# 2. Load the model from a checkpoint
model = ImageClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/image_classification_model.pt")

# 3. Attach the Serializer
# 3. Attach the Output
model.serializer = Probabilities()
ethanwharris marked this conversation as resolved.
Show resolved Hide resolved

# 4. Predict whether the image contains an ant or a bee
Expand Down
6 changes: 3 additions & 3 deletions docs/source/template/optional.rst
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,16 @@ Add output serializers to your Task

We recommend that you do most of the heavy lifting in the :class:`~flash.core.data.process.Postprocess`.
Specifically, it should include any formatting and transforms that should always be applied to the predictions.
If you want to support different use cases that require different prediction formats, you should add some :class:`~flash.core.data.process.Serializer` implementations in a ``serialization.py`` file.
If you want to support different use cases that require different prediction formats, you should add some :class:`~flash.core.data.io.output.Output` implementations in a ``serialization.py`` file.
ethanwharris marked this conversation as resolved.
Show resolved Hide resolved

Some good examples are in `flash/core/classification.py <https://github.com/PyTorchLightning/lightning-flash/blob/master/flash/core/classification.py>`_.
Here's the :class:`~flash.core.classification.Classes` :class:`~flash.core.data.process.Serializer`:
Here's the :class:`~flash.core.classification.Classes` :class:`~flash.core.data.io.output.Output`:

.. literalinclude:: ../../../flash/core/classification.py
:language: python
:pyobject: Classes

Alternatively, here's the :class:`~flash.core.classification.Logits` :class:`~flash.core.data.process.Serializer`:
Alternatively, here's the :class:`~flash.core.classification.Logits` :class:`~flash.core.data.io.output.Output`:

.. literalinclude:: ../../../flash/core/classification.py
:language: python
Expand Down
7 changes: 4 additions & 3 deletions flash/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@
from flash.core.data.data_source import DataSource
from flash.core.data.datasets import FlashDataset, FlashIterableDataset
from flash.core.data.input_transform import InputTransform
from flash.core.data.process import Postprocess, Preprocess, Serializer
from flash.core.data.io.output import Output
from flash.core.data.process import Postprocess, Preprocess
from flash.core.model import Task # noqa: E402
from flash.core.trainer import Trainer # noqa: E402

Expand All @@ -44,10 +45,10 @@
"FlashCallback",
"FlashDataset",
"FlashIterableDataset",
"Preprocess",
"InputTransform",
"Output",
"Postprocess",
"Serializer",
"Preprocess",
"Task",
"Trainer",
]
8 changes: 4 additions & 4 deletions flash/audio/speech_recognition/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from flash.core.model import Task
from flash.core.registry import FlashRegistry
from flash.core.utilities.imports import _AUDIO_AVAILABLE
from flash.core.utilities.types import LR_SCHEDULER_TYPE, OPTIMIZER_TYPE, SERIALIZER_TYPE
from flash.core.utilities.types import LR_SCHEDULER_TYPE, OPTIMIZER_TYPE, OUTPUT_TYPE

if _AUDIO_AVAILABLE:
from transformers import Wav2Vec2Processor
Expand All @@ -41,7 +41,7 @@ class SpeechRecognition(Task):
learning_rate: Learning rate to use for training, defaults to ``1e-5``.
optimizer: Optimizer to use for training.
lr_scheduler: The LR scheduler to use during training.
serializer: The :class:`~flash.core.data.process.Serializer` to use when serializing prediction outputs.
output: The :class:`~flash.core.data.io.output.Output` to use when serializing prediction outputs.
"""

backbones: FlashRegistry = SPEECH_RECOGNITION_BACKBONES
Expand All @@ -54,7 +54,7 @@ def __init__(
optimizer: OPTIMIZER_TYPE = "Adam",
lr_scheduler: LR_SCHEDULER_TYPE = None,
learning_rate: float = 1e-5,
serializer: SERIALIZER_TYPE = None,
output: OUTPUT_TYPE = None,
):
os.environ["TOKENIZERS_PARALLELISM"] = "TRUE"
# disable HF thousand warnings
Expand All @@ -68,7 +68,7 @@ def __init__(
optimizer=optimizer,
lr_scheduler=lr_scheduler,
learning_rate=learning_rate,
serializer=serializer,
output=output,
)

self.save_hyperparameters()
Expand Down
58 changes: 29 additions & 29 deletions flash/core/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

from flash.core.adapter import AdapterTask
from flash.core.data.data_source import DefaultDataKeys, LabelsState
from flash.core.data.process import Serializer
from flash.core.data.io.output import Output
from flash.core.model import Task
from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, lazy_import, requires

Expand Down Expand Up @@ -68,7 +68,7 @@ def __init__(
loss_fn: Optional[Callable] = None,
metrics: Union[torchmetrics.Metric, Mapping, Sequence, None] = None,
multi_label: bool = False,
serializer: Optional[Union[Serializer, Mapping[str, Serializer]]] = None,
output: Optional[Union[Output, Mapping[str, Output]]] = None,
**kwargs,
) -> None:

Expand All @@ -78,7 +78,7 @@ def __init__(
*args,
loss_fn=loss_fn,
metrics=metrics,
serializer=serializer or Classes(multi_label=multi_label),
output=output or Classes(multi_label=multi_label),
**kwargs,
)

Expand All @@ -91,7 +91,7 @@ def __init__(
loss_fn: Optional[Callable] = None,
metrics: Union[torchmetrics.Metric, Mapping, Sequence, None] = None,
multi_label: bool = False,
serializer: Optional[Union[Serializer, Mapping[str, Serializer]]] = None,
output: Optional[Union[Output, Mapping[str, Output]]] = None,
**kwargs,
) -> None:

Expand All @@ -101,13 +101,13 @@ def __init__(
*args,
loss_fn=loss_fn,
metrics=metrics,
serializer=serializer or Classes(multi_label=multi_label),
output=output or Classes(multi_label=multi_label),
**kwargs,
)


class ClassificationSerializer(Serializer):
"""A base class for classification serializers.
class ClassificationOutput(Output):
"""A base class for classification outputs.

Args:
multi_label: If true, treats outputs as multi label logits.
Expand All @@ -123,39 +123,39 @@ def multi_label(self) -> bool:
return self._mutli_label


class PredsClassificationSerializer(ClassificationSerializer):
"""A :class:`~flash.core.classification.ClassificationSerializer` which gets the
class PredsClassificationOutput(ClassificationOutput):
"""A :class:`~flash.core.classification.ClassificationOutput` which gets the
:attr:`~flash.core.data.data_source.DefaultDataKeys.PREDS` from the sample.
"""

def serialize(self, sample: Any) -> Any:
def transform(self, sample: Any) -> Any:
if isinstance(sample, Mapping) and DefaultDataKeys.PREDS in sample:
sample = sample[DefaultDataKeys.PREDS]
if not isinstance(sample, torch.Tensor):
sample = torch.tensor(sample)
return sample


class Logits(PredsClassificationSerializer):
"""A :class:`.Serializer` which simply converts the model outputs (assumed to be logits) to a list."""
class Logits(PredsClassificationOutput):
"""A :class:`.Output` which simply converts the model outputs (assumed to be logits) to a list."""

def serialize(self, sample: Any) -> Any:
return super().serialize(sample).tolist()
def transform(self, sample: Any) -> Any:
return super().transform(sample).tolist()


class Probabilities(PredsClassificationSerializer):
"""A :class:`.Serializer` which applies a softmax to the model outputs (assumed to be logits) and converts to a
class Probabilities(PredsClassificationOutput):
"""A :class:`.Output` which applies a softmax to the model outputs (assumed to be logits) and converts to a
list."""

def serialize(self, sample: Any) -> Any:
sample = super().serialize(sample)
def transform(self, sample: Any) -> Any:
sample = super().transform(sample)
if self.multi_label:
return torch.sigmoid(sample).tolist()
return torch.softmax(sample, -1).tolist()


class Classes(PredsClassificationSerializer):
"""A :class:`.Serializer` which applies an argmax to the model outputs (either logits or probabilities) and
class Classes(PredsClassificationOutput):
"""A :class:`.Output` which applies an argmax to the model outputs (either logits or probabilities) and
converts to a list.

Args:
Expand All @@ -168,8 +168,8 @@ def __init__(self, multi_label: bool = False, threshold: float = 0.5):

self.threshold = threshold

def serialize(self, sample: Any) -> Union[int, List[int]]:
sample = super().serialize(sample)
def transform(self, sample: Any) -> Union[int, List[int]]:
sample = super().transform(sample)
if self.multi_label:
one_hot = (sample.sigmoid() > self.threshold).int().tolist()
result = []
Expand All @@ -181,7 +181,7 @@ def serialize(self, sample: Any) -> Union[int, List[int]]:


class Labels(Classes):
"""A :class:`.Serializer` which converts the model outputs (either logits or probabilities) to the label of the
"""A :class:`.Output` which converts the model outputs (either logits or probabilities) to the label of the
argmax classification.

Args:
Expand All @@ -198,7 +198,7 @@ def __init__(self, labels: Optional[List[str]] = None, multi_label: bool = False
if labels is not None:
self.set_state(LabelsState(labels))

def serialize(self, sample: Any) -> Union[int, List[int], str, List[str]]:
def transform(self, sample: Any) -> Union[int, List[int], str, List[str]]:
labels = None

if self._labels is not None:
Expand All @@ -208,18 +208,18 @@ def serialize(self, sample: Any) -> Union[int, List[int], str, List[str]]:
if state is not None:
labels = state.labels

classes = super().serialize(sample)
classes = super().transform(sample)

if labels is not None:
if self.multi_label:
return [labels[cls] for cls in classes]
return labels[classes]
rank_zero_warn("No LabelsState was found, this serializer will act as a Classes serializer.", UserWarning)
rank_zero_warn("No LabelsState was found, this output will act as a Classes output.", UserWarning)
return classes


class FiftyOneLabels(ClassificationSerializer):
"""A :class:`.Serializer` which converts the model outputs to FiftyOne classification format.
class FiftyOneLabels(ClassificationOutput):
"""A :class:`.Output` which converts the model outputs to FiftyOne classification format.

Args:
labels: A list of labels, assumed to map the class index to the label for that class. If ``labels`` is not
Expand Down Expand Up @@ -254,7 +254,7 @@ def __init__(
if labels is not None:
self.set_state(LabelsState(labels))

def serialize(
def transform(
self,
sample: Any,
) -> Union[Classification, Classifications, Dict[str, Any], Dict[str, Any]]:
Expand Down
Loading