This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Data Pipeline V2: Rename all Output implementations to end in Output #1011

Merged 18 commits on Nov 30, 2021
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -32,6 +32,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Changed the `SpeechRecognition` task to use `AutoModelForCTC` rather than just `Wav2Vec2ForCTC` ([#874](https://github.com/PyTorchLightning/lightning-flash/pull/874))

- Added `Output` suffix to `Preds`, `FiftyOneDetectionLabels`, `SegmentationLabels`, `FiftyOneSegmentationLabels`, `DetectionLabels`, `Classes`, `FiftyOneLabels`, `Labels`, `Logits`, `Probabilities` ([#1011](https://github.com/PyTorchLightning/lightning-flash/pull/1011))

### Deprecated

- Deprecated `flash.core.data.process.Serializer` in favour of `flash.core.data.io.output.Output` ([#927](https://github.com/PyTorchLightning/lightning-flash/pull/927))
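For downstream code, a minimal before/after sketch of the rename (illustrative only, not part of the diff; the `ImageClassifier` arguments are placeholders):

```python
from flash.image import ImageClassifier

model = ImageClassifier(backbone="resnet18", num_classes=2)

# Before this PR (old name, now deprecated):
# from flash.core.classification import Labels
# model.output = Labels()

# After this PR (same behaviour, `Output` suffix added):
from flash.core.classification import LabelsOutput

model.output = LabelsOutput()
```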
2 changes: 1 addition & 1 deletion _notebooks
10 changes: 5 additions & 5 deletions docs/source/api/core.rst
@@ -26,14 +26,14 @@ _________________________
:nosignatures:
:template: classtemplate.rst

~flash.core.classification.Classes
~flash.core.classification.ClassesOutput
~flash.core.classification.ClassificationOutput
~flash.core.classification.ClassificationTask
~flash.core.classification.FiftyOneLabels
~flash.core.classification.Labels
~flash.core.classification.Logits
~flash.core.classification.FiftyOneLabelsOutput
~flash.core.classification.LabelsOutput
~flash.core.classification.LogitsOutput
~flash.core.classification.PredsClassificationOutput
~flash.core.classification.Probabilities
~flash.core.classification.ProbabilitiesOutput

flash.core.finetuning
_____________________
6 changes: 3 additions & 3 deletions docs/source/api/image.rst
@@ -45,7 +45,7 @@ ________________

detection.data.FiftyOneParser
detection.data.ObjectDetectionFiftyOneInput
detection.output.FiftyOneDetectionLabels
detection.output.FiftyOneDetectionLabelsOutput
detection.data.ObjectDetectionInputTransform

Keypoint Detection
@@ -103,8 +103,8 @@ ____________
segmentation.data.SemanticSegmentationFiftyOneInput
segmentation.data.SemanticSegmentationDeserializer
segmentation.model.SemanticSegmentationOutputTransform
segmentation.output.FiftyOneSegmentationLabels
segmentation.output.SegmentationLabels
segmentation.output.FiftyOneSegmentationLabelsOutput
segmentation.output.SegmentationLabelsOutput

.. autosummary::
:toctree: generated/
4 changes: 2 additions & 2 deletions docs/source/common/finetuning_example.rst
@@ -15,7 +15,7 @@ Here's an example of finetuning.
from pytorch_lightning import seed_everything

import flash
from flash.core.classification import Labels
from flash.core.classification import LabelsOutput
from flash.core.data.utils import download_data
from flash.image import ImageClassificationData, ImageClassifier

@@ -56,7 +56,7 @@ Once you've finetuned, use the model to predict:
.. testcode:: finetune

# Output predictions as labels, automatically inferred from the training data in part 2.
model.output = Labels()
model.output = LabelsOutput()

predictions = model.predict(
[
2 changes: 1 addition & 1 deletion docs/source/common/training_example.rst
@@ -15,7 +15,7 @@ Here's an example:
from pytorch_lightning import seed_everything

import flash
from flash.core.classification import Labels
from flash.core.classification import LabelsOutput
from flash.core.data.utils import download_data
from flash.image import ImageClassificationData, ImageClassifier

4 changes: 2 additions & 2 deletions docs/source/general/predictions.rst
@@ -64,7 +64,7 @@ reference below).

.. code-block:: python

from flash.core.classification import Probabilities
from flash.core.classification import ProbabilitiesOutput
from flash.core.data.utils import download_data
from flash.image import ImageClassifier

@@ -78,7 +78,7 @@
)

# 3. Attach the Output
model.output = Probabilities()
model.output = ProbabilitiesOutput()

# 4. Predict whether the image contains an ant or a bee
predictions = model.predict("data/hymenoptera_data/val/bees/65038344_52a45d090d.jpg")
6 changes: 3 additions & 3 deletions docs/source/integrations/fiftyone.rst
@@ -46,9 +46,9 @@ You can visualize predictions for classification, object detection, and
semantic segmentation tasks. Doing so is as easy as updating your model to use
one of the following outputs:

* :class:`FiftyOneLabels(return_filepath=True)<flash.core.classification.FiftyOneLabels>`
* :class:`FiftyOneSegmentationLabels(return_filepath=True)<flash.image.segmentation.output.FiftyOneSegmentationLabels>`
* :class:`FiftyOneDetectionLabels(return_filepath=True)<flash.image.detection.output.FiftyOneDetectionLabels>`
* :class:`FiftyOneLabelsOutput(return_filepath=True)<flash.core.classification.FiftyOneLabelsOutput>`
* :class:`FiftyOneSegmentationLabelsOutput(return_filepath=True)<flash.image.segmentation.output.FiftyOneSegmentationLabelsOutput>`
* :class:`FiftyOneDetectionLabelsOutput(return_filepath=True)<flash.image.detection.output.FiftyOneDetectionLabelsOutput>`

The :func:`~flash.core.integrations.fiftyone.visualize` function then lets you visualize
your predictions in the
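Putting the renamed outputs together with the `visualize` function described above, a rough sketch of the workflow (backbone, class count, and image path are assumptions for illustration):

```python
from flash.core.classification import FiftyOneLabelsOutput
from flash.core.integrations.fiftyone import visualize
from flash.image import ImageClassifier

model = ImageClassifier(backbone="resnet18", num_classes=2)
# return_filepath=True keeps the source image path with each prediction so FiftyOne can link them
model.output = FiftyOneLabelsOutput(return_filepath=True)

predictions = model.predict(["data/hymenoptera_data/val/bees/65038344_52a45d090d.jpg"])
session = visualize(predictions)  # opens the FiftyOne App with the predictions loaded
session.wait()                    # keep the script alive while you inspect the results
```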
8 changes: 4 additions & 4 deletions docs/source/template/optional.rst
@@ -28,17 +28,17 @@ Specifically, it should include any formatting and transforms that should always
If you want to support different use cases that require different prediction formats, you should add some :class:`~flash.core.data.io.output.Output` implementations in an ``output.py`` file.

Some good examples are in `flash/core/classification.py <https://github.com/PyTorchLightning/lightning-flash/blob/master/flash/core/classification.py>`_.
Here's the :class:`~flash.core.classification.Classes` :class:`~flash.core.data.io.output.Output`:
Here's the :class:`~flash.core.classification.ClassesOutput` :class:`~flash.core.data.io.output.Output`:

.. literalinclude:: ../../../flash/core/classification.py
:language: python
:pyobject: Classes
:pyobject: ClassesOutput

Alternatively, here's the :class:`~flash.core.classification.Logits` :class:`~flash.core.data.io.output.Output`:
Alternatively, here's the :class:`~flash.core.classification.LogitsOutput` :class:`~flash.core.data.io.output.Output`:

.. literalinclude:: ../../../flash/core/classification.py
:language: python
:pyobject: Logits
:pyobject: LogitsOutput

Take a look at :ref:`predictions` to learn more.
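For a prediction format not covered by the built-ins, a hypothetical custom output following the new naming convention might look like this (illustrative; `TopKOutput` is not part of this PR):

```python
from typing import Any, List

from flash.core.data.io.output import Output


class TopKOutput(Output):
    """Hypothetical output returning the indices of the k highest-scoring classes."""

    def __init__(self, k: int = 3):
        super().__init__()
        self.k = k

    def transform(self, sample: Any) -> List[int]:
        # `sample` is assumed to be a tensor of per-class scores, as in the built-in outputs.
        return sample.topk(self.k).indices.tolist()
```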

33 changes: 25 additions & 8 deletions flash/core/classification.py
@@ -16,7 +16,7 @@
import torch
import torch.nn.functional as F
import torchmetrics
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities import rank_zero_deprecation, rank_zero_warn

from flash.core.adapter import AdapterTask
from flash.core.data.io.classification_input import ClassificationState
@@ -79,7 +79,7 @@ def __init__(
*args,
loss_fn=loss_fn,
metrics=metrics,
output=output or Classes(multi_label=multi_label),
output=output or ClassesOutput(multi_label=multi_label),
**kwargs,
)

@@ -102,7 +102,7 @@ def __init__(
*args,
loss_fn=loss_fn,
metrics=metrics,
output=output or Classes(multi_label=multi_label),
output=output or ClassesOutput(multi_label=multi_label),
**kwargs,
)

@@ -137,14 +137,14 @@ def transform(self, sample: Any) -> Any:
return sample


class Logits(PredsClassificationOutput):
class LogitsOutput(PredsClassificationOutput):
"""A :class:`.Output` which simply converts the model outputs (assumed to be logits) to a list."""

def transform(self, sample: Any) -> Any:
return super().transform(sample).tolist()


class Probabilities(PredsClassificationOutput):
class ProbabilitiesOutput(PredsClassificationOutput):
"""A :class:`.Output` which applies a softmax to the model outputs (assumed to be logits) and converts to a
list."""

@@ -155,7 +155,7 @@ def transform(self, sample: Any) -> Any:
return torch.softmax(sample, -1).tolist()


class Classes(PredsClassificationOutput):
class ClassesOutput(PredsClassificationOutput):
"""A :class:`.Output` which applies an argmax to the model outputs (either logits or probabilities) and
converts to a list.

@@ -181,7 +181,7 @@ def transform(self, sample: Any) -> Union[int, List[int]]:
return torch.argmax(sample, -1).tolist()


class Labels(Classes):
class LabelsOutput(ClassesOutput):
"""A :class:`.Output` which converts the model outputs (either logits or probabilities) to the label of the
argmax classification.

@@ -219,7 +219,7 @@ def transform(self, sample: Any) -> Union[int, List[int], str, List[str]]:
return classes


class FiftyOneLabels(ClassificationOutput):
class FiftyOneLabelsOutput(ClassificationOutput):
"""A :class:`.Output` which converts the model outputs to FiftyOne classification format.

Args:
@@ -339,3 +339,20 @@ def transform(
filepath = sample[DataKeys.METADATA]["filepath"]
return {"filepath": filepath, "predictions": fo_predictions}
return fo_predictions


class Labels(LabelsOutput):
def __init__(self, labels: Optional[List[str]] = None, multi_label: bool = False, threshold: float = 0.5):
rank_zero_deprecation(
"`Labels` was deprecated in v0.6.0 and will be removed in v0.7.0. Please use `LabelsOutput` instead."
)
super().__init__(labels=labels, multi_label=multi_label, threshold=threshold)


class Probabilities(ProbabilitiesOutput):
def __init__(self, multi_label: bool = False):
rank_zero_deprecation(
"`Probabilities` was deprecated in v0.6.0 and will be removed in v0.7.0. Please use `ProbabilitiesOutput` instead."
)
super().__init__(multi_label=multi_label)
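Since the old names now subclass the renamed classes and only add a deprecation warning, existing user code keeps working unchanged. A sketch of the expected behaviour, based on the stubs above:

```python
from flash.core.classification import Labels, LabelsOutput

output = Labels()  # warns: `Labels` was deprecated in v0.6.0 and will be removed in v0.7.0. ...
assert isinstance(output, LabelsOutput)  # the old name resolves to the renamed implementation
```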
2 changes: 1 addition & 1 deletion flash/core/data/output.py
@@ -17,7 +17,7 @@
from flash.core.data.io.output import Output


class Preds(Output):
class PredsOutput(Output):
"""A :class:`~flash.core.data.io.output.Output` which returns the "preds" from the model outputs."""

def transform(self, sample: Any) -> Union[int, List[int]]:
4 changes: 2 additions & 2 deletions flash/image/classification/model.py
@@ -17,7 +17,7 @@
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from torch import nn

from flash.core.classification import ClassificationAdapterTask, Labels
from flash.core.classification import ClassificationAdapterTask, LabelsOutput
from flash.core.registry import FlashRegistry
from flash.core.utilities.types import LOSS_FN_TYPE, LR_SCHEDULER_TYPE, METRICS_TYPE, OPTIMIZER_TYPE, OUTPUT_TYPE
from flash.image.classification.adapters import TRAINING_STRATEGIES
@@ -136,7 +136,7 @@ def __init__(
optimizer=optimizer,
lr_scheduler=lr_scheduler,
multi_label=multi_label,
output=output or Labels(multi_label=multi_label),
output=output or LabelsOutput(multi_label=multi_label),
)

@classmethod
4 changes: 2 additions & 2 deletions flash/image/detection/model.py
@@ -14,7 +14,7 @@
from typing import Any, Dict, List, Optional

from flash.core.adapter import AdapterTask
from flash.core.data.output import Preds
from flash.core.data.output import PredsOutput
from flash.core.registry import FlashRegistry
from flash.core.utilities.types import LR_SCHEDULER_TYPE, OPTIMIZER_TYPE, OUTPUT_TYPE
from flash.image.detection.backbones import OBJECT_DETECTION_HEADS
@@ -73,7 +73,7 @@ def __init__(
learning_rate=learning_rate,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
output=output or Preds(),
output=output or PredsOutput(),
)

def _ci_benchmark_fn(self, history: List[Dict[str, Any]]) -> None:
4 changes: 2 additions & 2 deletions flash/image/detection/output.py
@@ -29,7 +29,7 @@
fo = None


class FiftyOneDetectionLabels(Output):
class FiftyOneDetectionLabelsOutput(Output):
"""A :class:`.Output` which converts model outputs to FiftyOne detection format.

Args:
@@ -58,7 +58,7 @@ def __init__(

def transform(self, sample: Dict[str, Any]) -> Union[Detections, Dict[str, Any]]:
if DataKeys.METADATA not in sample:
raise ValueError("sample requires DefaultDataKeys.METADATA to use a FiftyOneDetectionLabels output.")
raise ValueError("sample requires DataKeys.METADATA to use a FiftyOneDetectionLabelsOutput output.")

labels = None
if self._labels is not None:
4 changes: 2 additions & 2 deletions flash/image/face_detection/model.py
@@ -35,7 +35,7 @@
import fastface as ff


class DetectionLabels(Output):
class DetectionLabelsOutput(Output):
"""A :class:`.Output` which extracts predictions from sample dict."""

def transform(self, sample: Any) -> Dict[str, Any]:
@@ -89,7 +89,7 @@ def __init__(
learning_rate=learning_rate,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
output=output or DetectionLabels(),
output=output or DetectionLabelsOutput(),
input_transform=input_transform or FaceDetectionInputTransform(),
)

4 changes: 2 additions & 2 deletions flash/image/instance_segmentation/model.py
@@ -17,7 +17,7 @@

from flash.core.adapter import AdapterTask
from flash.core.data.data_pipeline import DataPipeline
from flash.core.data.output import Preds
from flash.core.data.output import PredsOutput
from flash.core.registry import FlashRegistry
from flash.core.utilities.types import LR_SCHEDULER_TYPE, OPTIMIZER_TYPE, OUTPUT_TYPE
from flash.image.instance_segmentation.backbones import INSTANCE_SEGMENTATION_HEADS
@@ -80,7 +80,7 @@ def __init__(
learning_rate=learning_rate,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
output=output or Preds(),
output=output or PredsOutput(),
)

def _ci_benchmark_fn(self, history: List[Dict[str, Any]]) -> None:
4 changes: 2 additions & 2 deletions flash/image/keypoint_detection/model.py
@@ -14,7 +14,7 @@
from typing import Any, Dict, List, Optional

from flash.core.adapter import AdapterTask
from flash.core.data.output import Preds
from flash.core.data.output import PredsOutput
from flash.core.registry import FlashRegistry
from flash.core.utilities.types import LR_SCHEDULER_TYPE, OPTIMIZER_TYPE, OUTPUT_TYPE
from flash.image.keypoint_detection.backbones import KEYPOINT_DETECTION_HEADS
@@ -76,7 +76,7 @@ def __init__(
learning_rate=learning_rate,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
output=output or Preds(),
output=output or PredsOutput(),
)

def _ci_benchmark_fn(self, history: List[Dict[str, Any]]) -> None:
12 changes: 6 additions & 6 deletions flash/image/segmentation/data.py
@@ -46,7 +46,7 @@
)
from flash.core.utilities.stages import RunningStage
from flash.image.data import ImageDeserializer, IMG_EXTENSIONS
from flash.image.segmentation.output import SegmentationLabels
from flash.image.segmentation.output import SegmentationLabelsOutput
from flash.image.segmentation.transforms import default_transforms, predict_default_transforms, train_default_transforms

SampleCollection = None
@@ -244,7 +244,7 @@ def __init__(
self.image_size = image_size
self.num_classes = num_classes
if num_classes:
labels_map = labels_map or SegmentationLabels.create_random_labels_map(num_classes)
labels_map = labels_map or SegmentationLabelsOutput.create_random_labels_map(num_classes)

super().__init__(
train_transform=train_transform,
@@ -329,9 +329,9 @@ def from_input(

num_classes = input_transform_kwargs["num_classes"]

labels_map = getattr(input_transform_kwargs, "labels_map", None) or SegmentationLabels.create_random_labels_map(
num_classes
)
labels_map = getattr(
input_transform_kwargs, "labels_map", None
) or SegmentationLabelsOutput.create_random_labels_map(num_classes)

data_fetcher = data_fetcher or cls.configure_data_fetcher(labels_map)

@@ -494,7 +494,7 @@ def _show_images_and_labels(self, data: List[Any], num_samples: int, title: str)
raise TypeError(f"Unknown data type. Got: {type(data)}.")
# convert images and labels to numpy and stack horizontally
image_vis: np.ndarray = self._to_numpy(image.byte())
label_tmp: torch.Tensor = SegmentationLabels.labels_to_image(label.squeeze().byte(), self.labels_map)
label_tmp: torch.Tensor = SegmentationLabelsOutput.labels_to_image(label.squeeze().byte(), self.labels_map)
label_vis: np.ndarray = self._to_numpy(label_tmp)
img_vis = np.hstack((image_vis, label_vis))
# send to visualiser
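For context on the two static helpers used above, a small hedged sketch of how they combine to colour a segmentation mask (the mask shape and class count are made up):

```python
import torch

from flash.image.segmentation.output import SegmentationLabelsOutput

num_classes = 4
labels_map = SegmentationLabelsOutput.create_random_labels_map(num_classes)  # class index -> colour

mask = torch.randint(0, num_classes, (128, 128)).byte()             # fake (H, W) tensor of class indices
image = SegmentationLabelsOutput.labels_to_image(mask, labels_map)  # colour image ready for plotting
```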
4 changes: 2 additions & 2 deletions flash/image/segmentation/model.py
@@ -34,7 +34,7 @@
)
from flash.image.segmentation.backbones import SEMANTIC_SEGMENTATION_BACKBONES
from flash.image.segmentation.heads import SEMANTIC_SEGMENTATION_HEADS
from flash.image.segmentation.output import SegmentationLabels
from flash.image.segmentation.output import SegmentationLabelsOutput

if _KORNIA_AVAILABLE:
import kornia as K
@@ -114,7 +114,7 @@ def __init__(
lr_scheduler=lr_scheduler,
metrics=metrics,
learning_rate=learning_rate,
output=output or SegmentationLabels(),
output=output or SegmentationLabelsOutput(),
output_transform=output_transform or self.output_transform_cls(),
)
