Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Clean Output #939

Merged
merged 4 commits into from
Nov 8, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

### Removed

- Removed `OutputMapping` ([#939](https://github.com/PyTorchLightning/lightning-flash/pull/939))

- Removed `Output.enable` and `Output.disable` ([#939](https://github.com/PyTorchLightning/lightning-flash/pull/939))


## [0.5.2] - 2021-11-05

Expand Down
1 change: 0 additions & 1 deletion docs/source/api/data.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ _________________________
:template: classtemplate.rst

~flash.core.data.io.output.Output
~flash.core.data.io.output.OutputMapping

flash.core.data.auto_dataset
____________________________
Expand Down
2 changes: 1 addition & 1 deletion flash/audio/speech_recognition/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class SpeechRecognition(Task):
learning_rate: Learning rate to use for training, defaults to ``1e-5``.
optimizer: Optimizer to use for training.
lr_scheduler: The LR scheduler to use during training.
output: The :class:`~flash.core.data.io.output.Output` to use when serializing prediction outputs.
output: The :class:`~flash.core.data.io.output.Output` to use when formatting prediction outputs.
"""

backbones: FlashRegistry = SPEECH_RECOGNITION_BACKBONES
Expand Down
38 changes: 2 additions & 36 deletions flash/core/data/io/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Mapping
from typing import Any

import torch

import flash
from flash.core.data.properties import Properties
from flash.core.data.utils import convert_to_modules

Expand All @@ -24,18 +23,6 @@ class Output(Properties):
"""An :class:`.Output` encapsulates a single :meth:`~flash.core.data.io.output.Output.transform` method which
is used to convert the model output into the desired output format when predicting."""

def __init__(self):
super().__init__()
self._is_enabled = True

def enable(self):
"""Enable output transformation."""
self._is_enabled = True

def disable(self):
"""Disable output transformation."""
self._is_enabled = False

@staticmethod
def transform(sample: Any) -> Any:
"""Convert the given sample into the desired output format.
Expand All @@ -49,28 +36,7 @@ def transform(sample: Any) -> Any:
return sample

def __call__(self, sample: Any) -> Any:
if self._is_enabled:
return self.transform(sample)
return sample


class OutputMapping(Output):
"""If the model output is a dictionary, then the :class:`.OutputMapping` enables each entry in the dictionary
to be passed to its own :class:`.Output`."""

def __init__(self, outputs: Mapping[str, Output]):
super().__init__()

self._outputs = outputs

def transform(self, sample: Any) -> Any:
if isinstance(sample, Mapping):
return {key: output.transform(sample[key]) for key, output in self._outputs.items()}
raise ValueError("The model output must be a mapping when using an OutputMapping.")

def attach_data_pipeline_state(self, data_pipeline_state: "flash.core.data.data_pipeline.DataPipelineState"):
for output in self._outputs.values():
output.attach_data_pipeline_state(data_pipeline_state)
return self.transform(sample)


class _OutputProcessor(torch.nn.Module):
Expand Down
11 changes: 4 additions & 7 deletions flash/core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
from flash.core.data.auto_dataset import BaseAutoDataset
from flash.core.data.data_pipeline import DataPipeline, DataPipelineState
from flash.core.data.data_source import DataSource
from flash.core.data.io.output import Output, OutputMapping
from flash.core.data.io.output import Output
from flash.core.data.process import Deserializer, DeserializerMapping, Postprocess, Preprocess
from flash.core.data.properties import ProcessState
from flash.core.optimizers.optimizers import _OPTIMIZERS_REGISTRY
Expand Down Expand Up @@ -319,8 +319,7 @@ class Task(DatasetProcessor, ModuleWrapperBase, LightningModule, metaclass=Check
deserialize the input
preprocess: :class:`~flash.core.data.process.Preprocess` to use as the default for this task.
postprocess: :class:`~flash.core.data.process.Postprocess` to use as the default for this task.
output: Either a single :class:`~flash.core.data.io.output.Output` or a mapping of these to
serialize the output e.g. convert the model output into the desired output format when predicting.
output: The :class:`~flash.core.data.io.output.Output` to use when formatting prediction outputs.
"""

optimizers: FlashRegistry = _OPTIMIZERS_REGISTRY
Expand Down Expand Up @@ -630,9 +629,7 @@ def output(self) -> Optional[Output]:

@torch.jit.unused
@output.setter
def output(self, output: Union[Output, Mapping[str, Output]]):
if isinstance(output, Mapping):
output = OutputMapping(output)
def output(self, output: Output):
self._output = output

@torch.jit.unused
Expand Down Expand Up @@ -662,7 +659,7 @@ def serializer(self) -> Optional[Output]:
"It will be removed in v%(remove_in)s.",
stream=functools.partial(warn, category=FutureWarning),
)
def serializer(self, serializer: Union[Output, Mapping[str, Output]]):
def serializer(self, serializer: Output):
self.output = serializer

def build_data_pipeline(
Expand Down
2 changes: 1 addition & 1 deletion flash/core/utilities/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,4 @@
DESERIALIZER_TYPE = Optional[Union[Deserializer, Mapping[str, Deserializer]]]
PREPROCESS_TYPE = Optional[Preprocess]
POSTPROCESS_TYPE = Optional[Postprocess]
OUTPUT_TYPE = Optional[Union[Output, Mapping[str, Output]]]
OUTPUT_TYPE = Optional[Output]
3 changes: 1 addition & 2 deletions flash/image/classification/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,7 @@ def fn_resnet(pretrained: bool = True):
`metric(preds,target)` and return a single scalar tensor. Defaults to :class:`torchmetrics.Accuracy`.
learning_rate: Learning rate to use for training, defaults to ``1e-3``.
multi_label: Whether the targets are multi-label or not.
output: An instance of :class:`~flash.core.data.io.output.Output` or a mapping consisting of such
to use when serializing prediction outputs.
output: The :class:`~flash.core.data.io.output.Output` to use when formatting prediction outputs.
training_strategy: string indicating the training strategy. Adjust if you want to use `learn2learn`
for doing meta-learning research
training_strategy_kwargs: Additional kwargs for setting the training strategy
Expand Down
3 changes: 1 addition & 2 deletions flash/image/detection/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,7 @@ class ObjectDetector(AdapterTask):
pretrained: Whether the model from torchvision should be loaded with its pretrained weights.
Has no effect for custom models.
learning_rate: The learning rate to use for training
output: An instance of :class:`~flash.core.data.io.output.Output` or a mapping consisting of such
to use when serializing prediction outputs.
output: The :class:`~flash.core.data.io.output.Output` to use when formatting prediction outputs.
kwargs: additional kwargs necessary for initializing the backbone task
"""

Expand Down
2 changes: 1 addition & 1 deletion flash/image/segmentation/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ class SemanticSegmentation(ClassificationTask):
`metric(preds,target)` and return a single scalar tensor. Defaults to :class:`torchmetrics.IOU`.
learning_rate: Learning rate to use for training.
multi_label: Whether the targets are multi-label or not.
output: The :class:`~flash.core.data.io.output.Output` to use when serializing prediction outputs.
output: The :class:`~flash.core.data.io.output.Output` to use when formatting prediction outputs.
postprocess: :class:`~flash.core.data.process.Postprocess` use for post processing samples.
"""

Expand Down
2 changes: 1 addition & 1 deletion flash/pointcloud/detection/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ class PointCloudObjectDetector(Task):
by the :class:`~flash.core.classification.ClassificationTask` depending on the ``multi_label`` argument.
learning_rate: The learning rate for the optimizer.
multi_label: If ``True``, this will be treated as a multi-label classification problem.
output: The :class:`~flash.core.data.io.output.Output` to use for prediction outputs.
output: The :class:`~flash.core.data.io.output.Output` to use when formatting prediction outputs.
lambda_loss_cls: The value to scale the loss classification.
lambda_loss_bbox: The value to scale the bounding boxes loss.
lambda_loss_dir: The value to scale the bounding boxes direction loss.
Expand Down
2 changes: 1 addition & 1 deletion flash/pointcloud/segmentation/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ class PointCloudSegmentation(ClassificationTask):
by the :class:`~flash.core.classification.ClassificationTask` depending on the ``multi_label`` argument.
learning_rate: The learning rate for the optimizer.
multi_label: If ``True``, this will be treated as a multi-label classification problem.
output: The :class:`~flash.core.data.io.output.Output` to use for prediction outputs.
output: The :class:`~flash.core.data.io.output.Output` to use when formatting prediction outputs.
"""

backbones: FlashRegistry = POINTCLOUD_SEGMENTATION_BACKBONES
Expand Down
2 changes: 1 addition & 1 deletion flash/tabular/classification/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ class TabularClassifier(ClassificationTask):
`metric(preds,target)` and return a single scalar tensor. Defaults to :class:`torchmetrics.Accuracy`.
learning_rate: Learning rate to use for training.
multi_label: Whether the targets are multi-label or not.
output: The :class:`~flash.core.data.io.output.Output` to use when serializing prediction outputs.
output: The :class:`~flash.core.data.io.output.Output` to use when formatting prediction outputs.
**tabnet_kwargs: Optional additional arguments for the TabNet model, see
`pytorch_tabnet <https://dreamquark-ai.github.io/tabnet/_modules/pytorch_tabnet/tab_network.html#TabNet>`_.
"""
Expand Down
2 changes: 1 addition & 1 deletion flash/tabular/regression/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ class TabularRegressor(RegressionTask):
`metric(preds,target)` and return a single scalar tensor.
learning_rate: Learning rate to use for training.
multi_label: Whether the targets are multi-label or not.
serializer: The :class:`~flash.core.data.process.Serializer` to use when serializing prediction outputs.
output: The :class:`~flash.core.data.io.output.Output` to use when formatting prediction outputs.
**tabnet_kwargs: Optional additional arguments for the TabNet model, see
`pytorch_tabnet <https://dreamquark-ai.github.io/tabnet/_modules/pytorch_tabnet/tab_network.html#TabNet>`_.
"""
Expand Down
2 changes: 1 addition & 1 deletion flash/template/classification/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ class TemplateSKLearnClassifier(ClassificationTask):
by the :class:`~flash.core.classification.ClassificationTask` depending on the ``multi_label`` argument.
learning_rate: The learning rate for the optimizer.
multi_label: If ``True``, this will be treated as a multi-label classification problem.
output: The :class:`~flash.core.data.io.output.Output` to use for prediction outputs.
output: The :class:`~flash.core.data.io.output.Output` to use when formatting prediction outputs.
"""

backbones: FlashRegistry = TEMPLATE_BACKBONES
Expand Down
2 changes: 1 addition & 1 deletion flash/text/classification/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class TextClassifier(ClassificationTask):
`metric(preds,target)` and return a single scalar tensor. Defaults to :class:`torchmetrics.Accuracy`.
learning_rate: Learning rate to use for training, defaults to `1e-3`
multi_label: Whether the targets are multi-label or not.
output: The :class:`~flash.core.data.io.output.Output` to use when serializing prediction outputs.
output: The :class:`~flash.core.data.io.output.Output` to use when formatting prediction outputs.
enable_ort: Enable Torch ONNX Runtime Optimization: https://onnxruntime.ai/docs/#onnx-runtime-for-training
"""

Expand Down
3 changes: 1 addition & 2 deletions flash/video/classification/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,7 @@ class VideoClassifier(ClassificationTask):
head: either a `nn.Module` or a callable function that converts the features extracted from the backbone
into class log probabilities (assuming default loss function). If `None`, will default to using
a single linear layer.
output: An instance of :class:`~flash.core.data.io.output.Output` that determines how the output
should be serialized e.g. convert the model output into the desired output format when predicting.
output: The :class:`~flash.core.data.io.output.Output` to use when formatting prediction outputs.
"""

backbones: FlashRegistry = _VIDEO_CLASSIFIER_BACKBONES
Expand Down
58 changes: 4 additions & 54 deletions tests/core/data/io/test_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,33 +14,25 @@
import os
from unittest.mock import Mock

import pytest
import torch
from torch.utils.data import DataLoader

from flash.core.classification import Labels
from flash.core.data.data_pipeline import DataPipeline, DataPipelineState
from flash.core.data.data_source import LabelsState
from flash.core.data.io.output import Output, OutputMapping
from flash.core.data.io.output import Output
from flash.core.data.process import DefaultPreprocess
from flash.core.data.properties import ProcessState
from flash.core.model import Task
from flash.core.trainer import Trainer


def test_output_enable_disable():
"""Tests that ``Output`` can be enabled and disabled correctly."""

def test_output():
"""Tests basic ``Output`` methods."""
my_output = Output()

assert my_output.transform("test") == "test"
my_output.transform = Mock()

my_output.disable()
assert my_output("test") == "test"
my_output.transform.assert_not_called()

my_output.enable()
my_output.transform = Mock()
my_output("test")
my_output.transform.assert_called_once()

Expand All @@ -65,45 +57,3 @@ def __init__(self):
model = CustomModel.load_from_checkpoint(checkpoint_file)
assert isinstance(model._data_pipeline_state, DataPipelineState)
assert model._data_pipeline_state._state[LabelsState] == LabelsState(["a", "b"])


def test_output_mapping():
"""Tests that ``OutputMapping`` correctly passes its inputs to the underlying outputs.

Also checks that state is retrieved / loaded correctly.
"""

output1 = Output()
output1.transform = Mock(return_value="test1")

class output1State(ProcessState):
pass

output2 = Output()
output2.transform = Mock(return_value="test2")

class output2State(ProcessState):
pass

output_mapping = OutputMapping({"key1": output1, "key2": output2})
assert output_mapping({"key1": "output1", "key2": "output2"}) == {"key1": "test1", "key2": "test2"}
output1.transform.assert_called_once_with("output1")
output2.transform.assert_called_once_with("output2")

with pytest.raises(ValueError, match="output must be a mapping"):
output_mapping("not a mapping")

output1_state = output1State()
output2_state = output2State()

output1.set_state(output1_state)
output2.set_state(output2_state)

data_pipeline_state = DataPipelineState()
output_mapping.attach_data_pipeline_state(data_pipeline_state)

assert output1._data_pipeline_state is data_pipeline_state
assert output2._data_pipeline_state is data_pipeline_state

assert data_pipeline_state.get_state(output1State) is output1_state
assert data_pipeline_state.get_state(output2State) is output2_state