Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Refactor image inputs and update to new input object #997

Merged
merged 31 commits into from
Nov 30, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
06f12ca
Refactor image inputs and update to new input object
ethanwharris Nov 25, 2021
306b058
Fix
ethanwharris Nov 25, 2021
0f8b09b
Fixes
ethanwharris Nov 25, 2021
a342f61
Fix syntax
ethanwharris Nov 25, 2021
b04c264
Updates
ethanwharris Nov 26, 2021
cff5eb4
Fixes
ethanwharris Nov 27, 2021
b7c546b
Add some tests
ethanwharris Nov 27, 2021
55e0871
Add speed test
ethanwharris Nov 29, 2021
0c36958
Add extreme targets case
ethanwharris Nov 29, 2021
d18f7ac
Array types
ethanwharris Nov 29, 2021
93f7096
Updates
ethanwharris Nov 29, 2021
371df14
Updates
ethanwharris Nov 29, 2021
c00355d
Fixes
ethanwharris Nov 29, 2021
eb06578
Fixes
ethanwharris Nov 29, 2021
3a4ffd8
Docs fixes
ethanwharris Nov 29, 2021
b5e3618
Merge branch 'master' into feature/image_inputs
ethanwharris Nov 29, 2021
35038b0
Update CHANGELOG.md
ethanwharris Nov 29, 2021
bd87f1d
Fix
ethanwharris Nov 29, 2021
4edd72d
Fixes
ethanwharris Nov 29, 2021
85a5f9f
Merge branch 'feature/image_inputs' of https://github.com/PyTorchLigh…
ethanwharris Nov 29, 2021
405a2e0
Fixes
ethanwharris Nov 29, 2021
f2b513d
Remove legacy CLI
ethanwharris Nov 29, 2021
4a9357b
Fixes
ethanwharris Nov 29, 2021
619aac1
Fixes
ethanwharris Nov 29, 2021
f859a55
Update flash/core/data/utilities/paths.py
ethanwharris Nov 29, 2021
1c4aac4
Merge branch 'master' into feature/image_inputs
ananyahjha93 Nov 30, 2021
a732e38
Rank zero warn
ethanwharris Nov 30, 2021
e301e9f
Fix audio
ethanwharris Nov 30, 2021
27171d1
Fix audio
ethanwharris Nov 30, 2021
5a00631
Fix active learning
ethanwharris Nov 30, 2021
cd66eb5
Fix AL
ethanwharris Nov 30, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/source/api/data.rst
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ ___________________________
~flash.core.data.io.input.MockDataset
~flash.core.data.io.input.NumpyInput
~flash.core.data.io.input.PathsInput
~flash.core.data.io.input.SequenceInput
~flash.core.data.io.input.ClassificationInput
ethanwharris marked this conversation as resolved.
Show resolved Hide resolved
~flash.core.data.io.input.TensorInput

.. autosummary::
Expand Down
2 changes: 1 addition & 1 deletion docs/source/api/image.rst
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ ________________
:template: classtemplate.rst

~data.ImageDeserializer
~data.ImageFiftyOneInput
~data.ImageClassificationFiftyOneInput
~data.ImageNumpyInput
~data.ImagePathsInput
~data.ImageTensorInput
37 changes: 37 additions & 0 deletions flash/core/data/io/classification_input.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, List, Optional

from flash.core.data.io.input import LabelsState
from flash.core.data.io.input_base import Input
from flash.core.data.utilities.labels import get_label_details, LabelDetails
from flash.core.data.utilities.samples import format_targets, to_samples


class ClassificationInput(Input):
def load_data(
self, inputs: List[Any], targets: Optional[List[Any]] = None, label_details: Optional[LabelDetails] = None
) -> List[Dict[str, Any]]:
samples = to_samples(inputs, targets=targets)
if targets is not None:
if self.training:
label_details = get_label_details(targets)
self.set_state(LabelsState.from_label_details(label_details))
self.num_classes = label_details.num_classes
self.label_details = label_details
elif label_details is None:
raise ValueError("In order to format evaluation targets correctly, ``label_details`` must be provided.")

samples = format_targets(samples, label_details)
return samples
12 changes: 9 additions & 3 deletions flash/core/data/io/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@

from flash.core.data.auto_dataset import AutoDataset, BaseAutoDataset, IterableAutoDataset
from flash.core.data.properties import ProcessState, Properties
from flash.core.data.utilities.labels import LabelDetails
from flash.core.data.utilities.paths import read_csv
from flash.core.data.utils import CurrentRunningStageFuncContext
from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, lazy_import, requires
Expand Down Expand Up @@ -141,6 +142,11 @@ class LabelsState(ProcessState):
label."""

labels: Optional[Sequence[str]]
is_multilabel: Optional[bool] = None

@classmethod
def from_label_details(cls, label_details: LabelDetails):
return cls(label_details.labels, is_multilabel=label_details.is_multilabel)


@dataclass(unsafe_hash=True, frozen=True)
Expand Down Expand Up @@ -374,7 +380,7 @@ class SequenceInput(
Generic[SEQUENCE_DATA_TYPE],
Input[Tuple[Sequence[SEQUENCE_DATA_TYPE], Optional[Sequence]]],
):
"""The ``SequenceInput`` implements default behaviours for data sources which expect the input to
"""The ``ClassificationInput`` implements default behaviours for data sources which expect the input to
ethanwharris marked this conversation as resolved.
Show resolved Hide resolved
:meth:`~flash.core.data.io.input.Input.load_data` to be a sequence of tuples (``(input, target)``
where target can be ``None``).

Expand Down Expand Up @@ -622,7 +628,7 @@ def load_sample(self, sample: Dict[str, Any], dataset: Optional[Any] = None) ->


class TensorInput(SequenceInput[torch.Tensor]):
"""The ``TensorInput`` is a ``SequenceInput`` which expects the input to
"""The ``TensorInput`` is a ``ClassificationInput`` which expects the input to
ethanwharris marked this conversation as resolved.
Show resolved Hide resolved
:meth:`~flash.core.data.io.input.Input.load_data` to be a sequence of ``torch.Tensor`` objects."""

def load_data(
Expand All @@ -637,7 +643,7 @@ def load_data(


class NumpyInput(SequenceInput[np.ndarray]):
"""The ``NumpyInput`` is a ``SequenceInput`` which expects the input to
"""The ``NumpyInput`` is a ``ClassificationInput`` which expects the input to
ethanwharris marked this conversation as resolved.
Show resolved Hide resolved
:meth:`~flash.core.data.io.input.Input.load_data` to be a sequence of ``np.ndarray`` objects."""


Expand Down
141 changes: 141 additions & 0 deletions flash/core/data/utilities/labels.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from enum import auto, Enum
from typing import Any, List, TypeVar, Union

T = TypeVar("T")


class LabelMode(Enum):
"""The ``LabelMode`` Enum describes the different supported label formats for targets in Flash."""

MULTI_LIST = auto()
MUTLI_COMMA_DELIMITED = auto()
SINGLE = auto()

def __add__(self, other: "LabelMode") -> "LabelMode":
"""The purpose of the addition here is to reduce the ``LabelMode`` over multiple targets. If one label mode
is a comma delimite string, then their sum will be also. Otherwise, we expect that both label modes are
consistent.

Raises:
ValueError: If the two label modes could not be resolved to a single mode.
"""
if self is LabelMode.SINGLE and other is LabelMode.SINGLE:
return LabelMode.SINGLE
elif self is LabelMode.MUTLI_COMMA_DELIMITED and other is LabelMode.MUTLI_COMMA_DELIMITED:
return LabelMode.MUTLI_COMMA_DELIMITED
elif self is LabelMode.MUTLI_COMMA_DELIMITED and other is LabelMode.SINGLE:
return LabelMode.MUTLI_COMMA_DELIMITED
elif self is LabelMode.SINGLE and other is LabelMode.MUTLI_COMMA_DELIMITED:
return LabelMode.MUTLI_COMMA_DELIMITED
elif self is LabelMode.MULTI_LIST and other is LabelMode.MULTI_LIST:
return LabelMode.MULTI_LIST
raise ValueError(
"Found inconsistent label modes. All targets should be either: single values, lists of values, or "
"comma-delimited strings."
)

@classmethod
def from_target(cls, target: Any) -> "LabelMode":
"""Determine the ``LabelMode`` for a given target.

Args:
target: A target that is one of: a single target, a list of targets, a comma delimited string.
"""
if isinstance(target, str):
# TODO: This could be a dangerous assumption if people happen to have a label that contains a comma
if "," in target:
return LabelMode.MUTLI_COMMA_DELIMITED
else:
return LabelMode.SINGLE
elif isinstance(target, List):
return LabelMode.MULTI_LIST
return LabelMode.SINGLE


def get_label_mode(targets: List[Any]) -> LabelMode:
"""Aggregate the ``LabelMode`` for a list of targets.

Args:
targets: The list of targets to get the label mode for.

Returns:
The total ``LabelMode`` of the list of targets.
"""
return sum(LabelMode.from_target(target) for target in targets)


@dataclass
class _Token:
"""The ``_Token`` dataclass is used to override the hash of a value to be the hash of it's string
representation.

This allows for using ``set`` with objects such as ``torch.Tensor`` which can have inconsistent hash codes.
"""

value: Any

def __eq__(self, other: "_Token") -> bool:
return str(self.value) == str(other.value)

def __hash__(self) -> int:
return hash(str(self.value))


class LabelDetails:
def __init__(self, labels: List[Any], label_mode: LabelMode):
self.labels = labels
self.label_mode = label_mode

self.label_to_idx = {label: idx for idx, label in enumerate(labels)}
self.is_multilabel = label_mode is LabelMode.MULTI_LIST or label_mode is LabelMode.MUTLI_COMMA_DELIMITED
self.num_classes = len(labels)

def format_target(self, target: Any):
if self.label_mode is LabelMode.MUTLI_COMMA_DELIMITED:
return [self.label_to_idx[t] for t in target.split(",")]
elif self.label_mode is LabelMode.MULTI_LIST:
return [self.label_to_idx[t] for t in target]
return self.label_to_idx[target]


def get_label_details(labels: Union[List[T], List[List[T]]]) -> LabelDetails:
"""Finds and sorts the unique labels in a list of single or multi label targets.

Args:
labels: A list of single or multi-label targets.

Returns:
(labels, is_multilabel): Tuple containing the sorted list of unique targets / labels and a boolean indicating
whether or not the targets were multilabel.
"""
label_mode = get_label_mode(labels)

tokens = []
if label_mode is LabelMode.MUTLI_COMMA_DELIMITED:
for label in labels:
tokens.extend(label.split(","))
elif label_mode is LabelMode.MULTI_LIST:
for label in labels:
tokens.extend(label)
else:
tokens = labels
tokens = map(_Token, tokens)

unique_tokens = list(set(tokens))
labels = list(map(lambda token: token.value, unique_tokens))
labels.sort()
return LabelDetails(labels, label_mode)
44 changes: 30 additions & 14 deletions flash/core/data/utilities/paths.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,16 @@
# limitations under the License.
import os
import warnings
from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, cast, Dict, List, Optional, Tuple, TypeVar, Union

import pandas as pd

from flash.core.utilities.imports import _PANDAS_GREATER_EQUAL_1_3_0

PATH_TYPE = Union[str, bytes, os.PathLike]

T = TypeVar("T")


# adapted from torchvision:
# https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py#L10
Expand Down Expand Up @@ -97,19 +99,16 @@ def isdir(path: Any) -> bool:
return False


def find_classes(dir: PATH_TYPE) -> Tuple[List[str], Dict[str, int]]:
"""Finds the class folders in a dataset. Ensures that no class is a subdirectory of another.
def list_subdirs(dir: PATH_TYPE) -> List[str]:
"""List the subdirectories of a given directory.

Args:
dir: Root directory path.
dir: The directory to scan.

Returns:
(classes, class_to_idx) where classes are relative to (dir), and class_to_idx is a dictionary.
The list of subdirectories.
"""
classes = [d.name for d in os.scandir(str(dir)) if d.is_dir()]
classes.sort()
class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
return classes, class_to_idx
return [d.name for d in os.scandir(str(dir)) if d.is_dir()]


def list_valid_files(
Expand All @@ -132,12 +131,29 @@ def list_valid_files(

if valid_extensions is None:
return paths
return list(
filter(
lambda file: has_file_allowed_extension(file, valid_extensions),
paths,
)
return [path for path in paths if has_file_allowed_extension(path, valid_extensions)]


def filter_valid_files(
files: List[PATH_TYPE], *additional_lists: List[Any], valid_extensions: Optional[Tuple[str, ...]] = None
) -> Tuple[List[Any], ...]:
"""Filter the given list of files and any additional lists to include only the entries that contain a file with
a valid extension.

Args:
files: The list of files to filter by.
additional_lists: Any additional lists to be filtered together with files.
valid_extensions: The tuple of valid file extensions.

Returns:
The filtered lists.
"""
if valid_extensions is None:
return files, *additional_lists
filtered = list(
filter(lambda sample: has_file_allowed_extension(sample[0], valid_extensions), zip(files, *additional_lists))
)
return tuple(zip(*filtered))


def read_csv(file: PATH_TYPE) -> pd.DataFrame:
Expand Down
49 changes: 49 additions & 0 deletions flash/core/data/utilities/samples.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, List, Optional, TypeVar

from flash.core.data.io.input import DataKeys
from flash.core.data.utilities.labels import LabelDetails

T = TypeVar("T")


def to_samples(inputs: List[Any], targets: Optional[List[Any]] = None) -> List[Dict[str, Any]]:
"""Package a list of inputs and, optionally, a list of targets in a list of dictionaries (samples).

Args:
inputs: The list of inputs to package as dictionaries.
targets: Optionally provide a list of targets to also be included in the samples.

Returns:
A list of sample dictionaries.
"""
if targets is None:
return [{DataKeys.INPUT: input} for input in inputs]
return [{DataKeys.INPUT: input, DataKeys.TARGET: target} for input, target in zip(inputs, targets)]


def format_targets(samples: List[Dict[str, Any]], label_details: LabelDetails) -> List[Dict[str, Any]]:
"""Use the provided ``LabelDetails`` to format all of the targets in the given list of samples.

Args:
samples: The list of samples containing targets to format.
label_details: The ``LabelDetails`` to format targets with.

Returns:
The list of samples with formatted targets.
"""
for sample in samples:
sample[DataKeys.TARGET] = label_details.format_target(sample[DataKeys.TARGET])
return samples
Loading