Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Refactor image inputs and update to new input object #997

Merged
merged 31 commits into from
Nov 30, 2021
Merged
Show file tree
Hide file tree
Changes from 25 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
06f12ca
Refactor image inputs and update to new input object
ethanwharris Nov 25, 2021
306b058
Fix
ethanwharris Nov 25, 2021
0f8b09b
Fixes
ethanwharris Nov 25, 2021
a342f61
Fix syntax
ethanwharris Nov 25, 2021
b04c264
Updates
ethanwharris Nov 26, 2021
cff5eb4
Fixes
ethanwharris Nov 27, 2021
b7c546b
Add some tests
ethanwharris Nov 27, 2021
55e0871
Add speed test
ethanwharris Nov 29, 2021
0c36958
Add extreme targets case
ethanwharris Nov 29, 2021
d18f7ac
Array types
ethanwharris Nov 29, 2021
93f7096
Updates
ethanwharris Nov 29, 2021
371df14
Updates
ethanwharris Nov 29, 2021
c00355d
Fixes
ethanwharris Nov 29, 2021
eb06578
Fixes
ethanwharris Nov 29, 2021
3a4ffd8
Docs fixes
ethanwharris Nov 29, 2021
b5e3618
Merge branch 'master' into feature/image_inputs
ethanwharris Nov 29, 2021
35038b0
Update CHANGELOG.md
ethanwharris Nov 29, 2021
bd87f1d
Fix
ethanwharris Nov 29, 2021
4edd72d
Fixes
ethanwharris Nov 29, 2021
85a5f9f
Merge branch 'feature/image_inputs' of https://github.com/PyTorchLigh…
ethanwharris Nov 29, 2021
405a2e0
Fixes
ethanwharris Nov 29, 2021
f2b513d
Remove legacy CLI
ethanwharris Nov 29, 2021
4a9357b
Fixes
ethanwharris Nov 29, 2021
619aac1
Fixes
ethanwharris Nov 29, 2021
f859a55
Update flash/core/data/utilities/paths.py
ethanwharris Nov 29, 2021
1c4aac4
Merge branch 'master' into feature/image_inputs
ananyahjha93 Nov 30, 2021
a732e38
Rank zero warn
ethanwharris Nov 30, 2021
e301e9f
Fix audio
ethanwharris Nov 30, 2021
27171d1
Fix audio
ethanwharris Nov 30, 2021
5a00631
Fix active learning
ethanwharris Nov 30, 2021
cd66eb5
Fix AL
ethanwharris Nov 30, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added `GraphEmbedder` task ([#592](https://github.com/PyTorchLightning/lightning-flash/pull/592))

- Added support for comma delimited multi-label targets to the `ImageClassifier` ([#997](https://github.com/PyTorchLightning/lightning-flash/pull/997))

### Changed

- Changed `DataSource` to `Input` ([#929](https://github.com/PyTorchLightning/lightning-flash/pull/929))
Expand Down
1 change: 0 additions & 1 deletion docs/source/api/data.rst
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,6 @@ ___________________________
~flash.core.data.io.input.InputFormat
~flash.core.data.io.input.FiftyOneInput
~flash.core.data.io.input.ImageLabelsMap
~flash.core.data.io.input.LabelsState
~flash.core.data.io.input.MockDataset
~flash.core.data.io.input.NumpyInput
~flash.core.data.io.input.PathsInput
Expand Down
3 changes: 1 addition & 2 deletions docs/source/api/image.rst
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ ______________
:template: classtemplate.rst

~classification.model.ImageClassifier
~classification.data.ImageClassificationFiftyOneInput
~classification.data.ImageClassificationData
~classification.data.ImageClassificationInputTransform

Expand Down Expand Up @@ -140,7 +141,5 @@ ________________
:template: classtemplate.rst

~data.ImageDeserializer
~data.ImageFiftyOneInput
~data.ImageNumpyInput
~data.ImagePathsInput
~data.ImageTensorInput
2 changes: 1 addition & 1 deletion docs/source/template/data.rst
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ We override our ``TemplateNumpyInput`` so that we can call ``super`` with the da
We perform two additional steps here to improve the user experience:

1. We set the ``num_classes`` attribute on the ``dataset``. If ``num_classes`` is set, it is automatically made available as a property of the :class:`~flash.core.data.data_module.DataModule`.
2. We create and set a :class:`~flash.core.data.io.input.LabelsState`. The labels provided here will be shared with the :class:`~flash.core.classification.Labels` output, so the user doesn't need to provide them.
2. We create and set a :class:`~flash.core.data.io.input.ClassificationState`. The labels provided here will be shared with the :class:`~flash.core.classification.Labels` output, so the user doesn't need to provide them.

Here's the code for the ``TemplateSKLearnInput.load_data`` method:

Expand Down
19 changes: 10 additions & 9 deletions flash/core/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@
from pytorch_lightning.utilities import rank_zero_warn

from flash.core.adapter import AdapterTask
from flash.core.data.io.input import DataKeys, LabelsState
from flash.core.data.io.classification_input import ClassificationState
from flash.core.data.io.input import DataKeys
from flash.core.data.io.output import Output
from flash.core.model import Task
from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, lazy_import, requires
Expand Down Expand Up @@ -186,7 +187,7 @@ class Labels(Classes):

Args:
labels: A list of labels, assumed to map the class index to the label for that class. If ``labels`` is not
provided, will attempt to get them from the :class:`.LabelsState`.
provided, will attempt to get them from the :class:`.ClassificationState`.
multi_label: If true, treats outputs as multi label logits.
threshold: The threshold to use for multi_label classification.
"""
Expand All @@ -196,15 +197,15 @@ def __init__(self, labels: Optional[List[str]] = None, multi_label: bool = False
self._labels = labels

if labels is not None:
self.set_state(LabelsState(labels))
self.set_state(ClassificationState(labels))

def transform(self, sample: Any) -> Union[int, List[int], str, List[str]]:
labels = None

if self._labels is not None:
labels = self._labels
else:
state = self.get_state(LabelsState)
state = self.get_state(ClassificationState)
if state is not None:
labels = state.labels

Expand All @@ -214,7 +215,7 @@ def transform(self, sample: Any) -> Union[int, List[int], str, List[str]]:
if self.multi_label:
return [labels[cls] for cls in classes]
return labels[classes]
rank_zero_warn("No LabelsState was found, this output will act as a Classes output.", UserWarning)
rank_zero_warn("No ClassificationState was found, this output will act as a Classes output.", UserWarning)
return classes


Expand All @@ -223,7 +224,7 @@ class FiftyOneLabels(ClassificationOutput):

Args:
labels: A list of labels, assumed to map the class index to the label for that class. If ``labels`` is not
provided, will attempt to get them from the :class:`.LabelsState`.
provided, will attempt to get them from the :class:`.ClassificationState`.
multi_label: If true, treats outputs as multi label logits.
threshold: A threshold to use to filter candidate labels. In the single label case, predictions below this
threshold will be replaced with None
Expand Down Expand Up @@ -252,7 +253,7 @@ def __init__(
self.return_filepath = return_filepath

if labels is not None:
self.set_state(LabelsState(labels))
self.set_state(ClassificationState(labels))

def transform(
self,
Expand All @@ -266,7 +267,7 @@ def transform(
if self._labels is not None:
labels = self._labels
else:
state = self.get_state(LabelsState)
state = self.get_state(ClassificationState)
if state is not None:
labels = state.labels

Expand Down Expand Up @@ -309,7 +310,7 @@ def transform(
logits=logits,
)
else:
rank_zero_warn("No LabelsState was found, int targets will be used as label strings", UserWarning)
rank_zero_warn("No ClassificationState was found, int targets will be used as label strings", UserWarning)

if self.multi_label:
classifications = []
Expand Down
6 changes: 5 additions & 1 deletion flash/core/data/data_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,10 @@ def __init__(

self.set_running_stages()

# Share state between input objects (this will be available in ``load_sample`` but not in ``load_data``)
data_pipeline = self.data_pipeline
data_pipeline.initialize()

@property
def train_dataset(self) -> Optional[Dataset]:
"""This property returns the train dataset."""
Expand Down Expand Up @@ -420,7 +424,7 @@ def num_classes(self) -> Optional[int]:

@property
def multi_label(self) -> Optional[bool]:
"""Property that returns the number of labels of the datamodule if a multilabel task."""
"""Property that returns ``True`` if this ``DataModule`` contains multi-label data."""
multi_label_train = getattr(self.train_dataset, "multi_label", None)
multi_label_val = getattr(self.val_dataset, "multi_label", None)
multi_label_test = getattr(self.test_dataset, "multi_label", None)
Expand Down
79 changes: 79 additions & 0 deletions flash/core/data/io/classification_input.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from functools import lru_cache
from typing import Any, List, Optional, Sequence

from flash.core.data.io.input_base import Input
from flash.core.data.properties import ProcessState
from flash.core.data.utilities.classification import (
get_target_details,
get_target_formatter,
get_target_mode,
TargetFormatter,
)


@dataclass(unsafe_hash=True, frozen=True)
class ClassificationState(ProcessState):
    """A :class:`~flash.core.data.properties.ProcessState` containing ``labels`` (a mapping from class index to
    label) and ``num_classes``.

    Set by ``ClassificationInput.load_target_metadata`` during training and shared through the data pipeline state
    so that outputs can recover the label names when formatting predictions.
    """

    # Ordered mapping from class index to human-readable label; ``None`` when labels are unknown.
    labels: Optional[Sequence[str]]
    # Total number of classes; ``None`` when it has not been computed (e.g. no training data seen yet).
    num_classes: Optional[int] = None


class ClassificationInput(Input):
    """The ``ClassificationInput`` class provides utility methods for handling classification targets.
    :class:`~flash.core.data.io.input_base.Input` objects that extend ``ClassificationInput`` should do the following:

    * In the ``load_data`` method, include a call to ``load_target_metadata``. This will determine the format of the
    targets and store metadata like ``labels`` and ``num_classes``.
    * In the ``load_sample`` method, use ``format_target`` to convert the target to a standard format for use with our
    tasks.
    """

    @property
    def target_formatter(self) -> TargetFormatter:
        """Get the :class:`~flash.core.data.utilities.classification.TargetFormatter` to use when formatting
        targets.

        The formatter is created lazily on first access and cached on the instance, so it is only instantiated
        once. A per-instance attribute is used instead of ``functools.lru_cache`` because caching a method with
        ``lru_cache`` keys the cache on ``self`` and therefore keeps every ``Input`` instance alive for the
        lifetime of the process.
        """
        formatter = getattr(self, "_target_formatter", None)
        if formatter is None:
            # No ``ClassificationState`` may have been shared (e.g. predicting without a prior training run);
            # fall back to ``None`` labels / num_classes rather than raising ``AttributeError`` on the state.
            # NOTE(review): assumes ``get_target_formatter`` accepts ``None`` for both — confirm against its
            # signature.
            labels, num_classes = None, None
            classification_state = self.get_state(ClassificationState)
            if classification_state is not None:
                labels = classification_state.labels
                num_classes = classification_state.num_classes
            formatter = get_target_formatter(self.target_mode, labels, num_classes)
            self._target_formatter = formatter
        return formatter

    def load_target_metadata(self, targets: List[Any]) -> None:
        """Determine the target format and store the ``labels`` and ``num_classes``.

        Args:
            targets: The list of targets.
        """
        self.target_mode = get_target_mode(targets)
        self.multi_label = self.target_mode.multi_label
        if self.training:
            # Labels and class counts are derived from the training split only, then shared with the other
            # stages (and outputs) through the process state.
            self.labels, self.num_classes = get_target_details(targets, self.target_mode)
            self.set_state(ClassificationState(self.labels, self.num_classes))

    def format_target(self, target: Any) -> Any:
        """Format a single target according to the previously computed target format and metadata.

        Args:
            target: The target to format.

        Returns:
            The formatted target.
        """
        return self.target_formatter(target)
29 changes: 11 additions & 18 deletions flash/core/data/io/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,9 @@
from tqdm import tqdm

from flash.core.data.auto_dataset import AutoDataset, BaseAutoDataset, IterableAutoDataset
from flash.core.data.io.classification_input import ClassificationState
from flash.core.data.properties import ProcessState, Properties
from flash.core.data.utilities.paths import read_csv
from flash.core.data.utilities.data_frame import read_csv
from flash.core.data.utils import CurrentRunningStageFuncContext
from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, lazy_import, requires
from flash.core.utilities.stages import RunningStage
Expand All @@ -72,7 +73,7 @@ def has_file_allowed_extension(filename: str, extensions: Tuple[str, ...]) -> bo
Returns:
bool: True if the filename ends with one of given extensions
"""
return filename.lower().endswith(extensions)
return str(filename).lower().endswith(extensions)


# Credit to the PyTorchVision Team:
Expand Down Expand Up @@ -135,14 +136,6 @@ def has_len(data: Union[Sequence[Any], Iterable[Any]]) -> bool:
return False


@dataclass(unsafe_hash=True, frozen=True)
class LabelsState(ProcessState):
"""A :class:`~flash.core.data.properties.ProcessState` containing ``labels``, a mapping from class index to
label."""

labels: Optional[Sequence[str]]


@dataclass(unsafe_hash=True, frozen=True)
class ImageLabelsMap(ProcessState):

Expand Down Expand Up @@ -361,7 +354,7 @@ class DatasetInput(Input[Dataset]):

Args:
labels: Optionally pass the labels as a mapping from class index to label string. These will then be set as the
:class:`~flash.core.data.io.input.LabelsState`.
:class:`~flash.core.data.io.input.ClassificationState`.
"""

def load_sample(self, sample: Any, dataset: Optional[Any] = None) -> Mapping[str, Any]:
Expand All @@ -380,7 +373,7 @@ class SequenceInput(

Args:
labels: Optionally pass the labels as a mapping from class index to label string. These will then be set as the
:class:`~flash.core.data.io.input.LabelsState`.
:class:`~flash.core.data.io.input.ClassificationState`.
"""

def __init__(self, labels: Optional[Sequence[str]] = None):
Expand All @@ -389,7 +382,7 @@ def __init__(self, labels: Optional[Sequence[str]] = None):
self.labels = labels

if self.labels is not None:
self.set_state(LabelsState(self.labels))
self.set_state(ClassificationState(self.labels))

def load_data(
self,
Expand All @@ -415,7 +408,7 @@ class PathsInput(SequenceInput):
Args:
extensions: The file extensions supported by this data source (e.g. ``(".jpg", ".png")``).
labels: Optionally pass the labels as a mapping from class index to label string. These will then be set as the
:class:`~flash.core.data.io.input.LabelsState`.
:class:`~flash.core.data.io.input.ClassificationState`.
"""

def __init__(
Expand Down Expand Up @@ -459,7 +452,7 @@ def load_data(
classes, class_to_idx = self.find_classes(data)
if not classes:
return self.predict_load_data(data)
self.set_state(LabelsState(classes))
self.set_state(ClassificationState(classes))

if dataset is not None:
dataset.num_classes = len(classes)
Expand Down Expand Up @@ -577,17 +570,17 @@ def load_data(
if isinstance(target_keys, List):
dataset.multi_label = True
dataset.num_classes = len(target_keys)
self.set_state(LabelsState(target_keys))
self.set_state(ClassificationState(target_keys))
data_frame = data_frame.apply(partial(self._resolve_multi_target, target_keys), axis=1)
target_keys = target_keys[0]
else:
dataset.multi_label = False
if self.training:
labels = list(sorted(data_frame[target_keys].unique()))
dataset.num_classes = len(labels)
self.set_state(LabelsState(labels))
self.set_state(ClassificationState(labels))

labels = self.get_state(LabelsState)
labels = self.get_state(ClassificationState)

if labels is not None:
labels = labels.labels
Expand Down
3 changes: 2 additions & 1 deletion flash/core/data/io/input_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import functools
import os
import sys
from copy import copy
from typing import Any, cast, Dict, Iterable, MutableMapping, Optional, Sequence, Tuple, Union

from torch.utils.data import Dataset
Expand Down Expand Up @@ -147,7 +148,7 @@ def _call_load_sample(self, sample: Any) -> Any:
InputBase,
),
)
return load_sample(sample)
return load_sample(copy(sample))

@staticmethod
def load_data(*args: Any, **kwargs: Any) -> Union[Sequence, Iterable]:
Expand Down
Loading