From 735740e4cd63e404ae70b4ecd37fcc7f64547cb2 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Wed, 28 Apr 2021 14:24:43 +0100 Subject: [PATCH 01/78] Initial commit --- flash/core/classification.py | 13 +-- flash/data/data_source.py | 156 ++++++++++++++++++++++++++++ flash/data/process.py | 36 +++---- flash/vision/classification/data.py | 60 +++++------ 4 files changed, 209 insertions(+), 56 deletions(-) create mode 100644 flash/data/data_source.py diff --git a/flash/core/classification.py b/flash/core/classification.py index 346905b823..fbbc13bc8e 100644 --- a/flash/core/classification.py +++ b/flash/core/classification.py @@ -20,7 +20,8 @@ from pytorch_lightning.utilities import rank_zero_warn from flash.core.model import Task -from flash.data.process import ProcessState, Serializer +from flash.data.data_source import LabelsState +from flash.data.process import Serializer def binary_cross_entropy_with_logits(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: @@ -28,12 +29,6 @@ def binary_cross_entropy_with_logits(x: torch.Tensor, y: torch.Tensor) -> torch. return F.binary_cross_entropy_with_logits(x, y.float()) -@dataclass(unsafe_hash=True, frozen=True) -class ClassificationState(ProcessState): - - labels: Optional[List[str]] - - class ClassificationTask(Task): def __init__( @@ -140,7 +135,7 @@ class Labels(Classes): def __init__(self, labels: Optional[List[str]] = None, multi_label: bool = False, threshold: float = 0.5): super().__init__(multi_label=multi_label, threshold=threshold) self._labels = labels - self.set_state(ClassificationState(labels)) + self.set_state(LabelsState(labels)) def serialize(self, sample: Any) -> Union[int, List[int], str, List[str]]: labels = None @@ -148,7 +143,7 @@ def serialize(self, sample: Any) -> Union[int, List[int], str, List[str]]: if self._labels is not None: labels = self._labels else: - state = self.get_state(ClassificationState) + state = self.get_state(LabelsState) if state is not None: labels = state.labels diff --git a/flash/data/data_source.py b/flash/data/data_source.py new file mode 100644 index 0000000000..11dbd19305 --- /dev/null +++ b/flash/data/data_source.py @@ -0,0 +1,156 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
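The classification.py hunk above swaps the serializer-local ClassificationState for a LabelsState defined alongside the new data sources, so whichever component discovers the class names can publish them with set_state and the Labels serializer can read them back with get_state. Below is a minimal, self-contained sketch of that sharing pattern; the SharedState container is a simplified stand-in for Flash's Properties / ProcessState machinery, not the real API.

from dataclasses import dataclass
from typing import Dict, Optional, Sequence


@dataclass(unsafe_hash=True, frozen=True)
class LabelsState:
    """Immutable record of the class names, shared between pipeline components."""
    labels: Optional[Sequence[str]]


class SharedState:
    """Toy stand-in for the set_state/get_state plumbing in flash.data.process (not the real API)."""
    _store: Dict[type, object] = {}

    def set_state(self, state: object) -> None:
        self._store[type(state)] = state

    def get_state(self, state_type: type) -> Optional[object]:
        return self._store.get(state_type)


class FolderSource(SharedState):

    def load(self, classes: Sequence[str]) -> None:
        # Whichever component discovers the labels publishes them once.
        self.set_state(LabelsState(tuple(classes)))


class LabelsSerializer(SharedState):

    def serialize(self, class_index: int) -> str:
        state = self.get_state(LabelsState)
        # Fall back to the raw index if no labels were ever published.
        return state.labels[class_index] if state is not None else str(class_index)


source, serializer = FolderSource(), LabelsSerializer()
source.load(["cat", "dog"])
assert serializer.serialize(1) == "dog"
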
+import os +import pathlib +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Any, Dict, Generic, Iterable, List, Mapping, Optional, Sequence, Tuple, Union + +import numpy as np +from torch.nn import Module +from torchvision.datasets.folder import has_file_allowed_extension, make_dataset + +from flash.data.process import ProcessState, Properties + + +@dataclass(unsafe_hash=True, frozen=True) +class LabelsState(ProcessState): + + labels: Optional[Sequence[str]] + + +class DataSource(Properties, Module, ABC): + + def __init__( + self, + train_data: Optional[Any] = None, + val_data: Optional[Any] = None, + test_data: Optional[Any] = None, + predict_data: Optional[Any] = None, + ): + super().__init__() + + self.train_data = train_data + self.val_data = val_data + self.test_data = test_data + self.predict_data = predict_data + + def train_load_data(self, dataset: Optional[Any] = None) -> Iterable[Mapping[str, Any]]: + return self.load_data(self.train_data, dataset) + + def val_load_data(self, dataset: Optional[Any] = None) -> Iterable[Mapping[str, Any]]: + return self.load_data(self.val_data, dataset) + + def test_load_data(self, dataset: Optional[Any] = None) -> Iterable[Mapping[str, Any]]: + return self.load_data(self.test_data, dataset) + + def predict_load_data(self, dataset: Optional[Any] = None) -> Iterable[Mapping[str, Any]]: + return self.load_data(self.predict_data, dataset) + + @abstractmethod + def load_data( + self, + data: Any, + dataset: Optional[Any] = None + ) -> Iterable[Mapping[str, Any]]: # TODO: decide what type this should be + """Loads entire data from Dataset. The input ``data`` can be anything, but you need to return a Mapping. + + Example:: + + # data: "." + # output: [("./cat/1.png", 1), ..., ("./dog/10.png", 0)] + + output: Mapping = load_data(data) + + """ + + @abstractmethod + def load_sample(self, sample: Mapping[str, Any], dataset: Optional[Any] = None) -> Any: + """Loads single sample from dataset""" + + +class SequenceDataSource(DataSource, ABC): + + def __init__( + self, + train_inputs: Optional[Sequence[Any]] = None, + train_targets: Optional[Sequence[Any]] = None, + val_inputs: Optional[Sequence[Any]] = None, + val_targets: Optional[Sequence[Any]] = None, + test_inputs: Optional[Sequence[Any]] = None, + test_targets: Optional[Sequence[Any]] = None, + predict_inputs: Optional[Sequence[Any]] = None, + predict_targets: Optional[Sequence[Any]] = None, + labels: Optional[Sequence[str]] = None + ): + super().__init__( + train_data=(train_inputs, train_targets), + val_data=(val_inputs, val_targets), + test_data=(test_inputs, test_targets), + predict_data=(predict_inputs, predict_targets), + ) + + self.labels = labels + + if self.labels is not None: + self.set_state(LabelsState(self.labels)) + + def load_data(self, data: Any, dataset: Optional[Any] = None) -> Iterable[Mapping[str, Any]]: + inputs, targets = data + if targets is None: + return [{'input': input} for input in inputs] + return [{'input': input, 'target': target} for input, target in zip(inputs, targets)] + + +class FolderDataSource(DataSource, ABC): + + def __init__( + self, + train_folder: Optional[Union[str, pathlib.Path, list]] = None, + val_folder: Optional[Union[str, pathlib.Path, list]] = None, + test_folder: Optional[Union[str, pathlib.Path, list]] = None, + predict_folder: Optional[Union[str, pathlib.Path, list]] = None, + extensions: Optional[Tuple[str, ...]] = None, + ): + super().__init__( + train_data=train_folder, + val_data=val_folder, + 
test_data=test_folder, + predict_data=predict_folder, + ) + + self.extensions = extensions + + @staticmethod + def find_classes(dir: str) -> Tuple[List[str], Dict[str, int]]: + """ + Finds the class folders in a dataset. Ensures that no class is a subdirectory of another. + + Args: + dir: Root directory path. + + Returns: + tuple: (classes, class_to_idx) where classes are relative to (dir), and class_to_idx is a dictionary. + """ + classes = [d.name for d in os.scandir(dir) if d.is_dir()] + classes.sort() + class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)} + return classes, class_to_idx + + def load_data(self, data: Any, dataset: Optional[Any] = None) -> Iterable[Mapping[str, Any]]: + classes, class_to_idx = self.find_classes(data) + self.set_state(LabelsState(classes)) + dataset.num_classes = len(classes) + data = make_dataset(data, class_to_idx, extensions=self.extensions) + return [{'input': input, 'target': target} for input, target in data] diff --git a/flash/data/process.py b/flash/data/process.py index e44418f1b3..09a2353e68 100644 --- a/flash/data/process.py +++ b/flash/data/process.py @@ -427,24 +427,24 @@ def add_callbacks(self, callbacks: List['FlashCallback']): _callbacks = [c for c in callbacks if c not in self._callbacks] self._callbacks.extend(_callbacks) - @classmethod - def load_data(cls, data: Any, dataset: Optional[Any] = None) -> Mapping: - """Loads entire data from Dataset. The input ``data`` can be anything, but you need to return a Mapping. - - Example:: - - # data: "." - # output: [("./cat/1.png", 1), ..., ("./dog/10.png", 0)] - - output: Mapping = load_data(data) - - """ - return data - - @classmethod - def load_sample(cls, sample: Any, dataset: Optional[Any] = None) -> Any: - """Loads single sample from dataset""" - return sample + # @classmethod + # def load_data(cls, data: Any, dataset: Optional[Any] = None) -> Mapping: + # """Loads entire data from Dataset. The input ``data`` can be anything, but you need to return a Mapping. + # + # Example:: + # + # # data: "." + # # output: [("./cat/1.png", 1), ..., ("./dog/10.png", 0)] + # + # output: Mapping = load_data(data) + # + # """ + # return data + # + # @classmethod + # def load_sample(cls, sample: Any, dataset: Optional[Any] = None) -> Any: + # """Loads single sample from dataset""" + # return sample def pre_tensor_transform(self, sample: Any) -> Any: """Transforms to apply on a single object.""" diff --git a/flash/vision/classification/data.py b/flash/vision/classification/data.py index 9a59032cbe..7d303ea522 100644 --- a/flash/vision/classification/data.py +++ b/flash/vision/classification/data.py @@ -27,11 +27,12 @@ from torchvision import transforms as T from torchvision.datasets.folder import has_file_allowed_extension, IMG_EXTENSIONS, make_dataset -from flash.core.classification import ClassificationState +# from flash.core.classification import ClassificationState from flash.data.auto_dataset import AutoDataset from flash.data.base_viz import BaseVisualization # for viz from flash.data.callback import BaseDataFetcher from flash.data.data_module import DataModule +from flash.data.data_source import LabelsState from flash.data.process import Preprocess from flash.utils.imports import _KORNIA_AVAILABLE, _MATPLOTLIB_AVAILABLE @@ -75,21 +76,21 @@ def get_state_dict(self) -> Dict[str, Any]: def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool): return cls(**state_dict) - @staticmethod - def _find_classes(dir: str) -> Tuple: - """ - Finds the class folders in a dataset. 
- Args: - dir: Root directory path. - Returns: - tuple: (classes, class_to_idx) where classes are relative to (dir), and class_to_idx is a dictionary. - Ensures: - No class is a subdirectory of another. - """ - classes = [d.name for d in os.scandir(dir) if d.is_dir()] - classes.sort() - class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)} - return classes, class_to_idx + # @staticmethod + # def _find_classes(dir: str) -> Tuple: + # """ + # Finds the class folders in a dataset. + # Args: + # dir: Root directory path. + # Returns: + # tuple: (classes, class_to_idx) where classes are relative to (dir), and class_to_idx is a dictionary. + # Ensures: + # No class is a subdirectory of another. + # """ + # classes = [d.name for d in os.scandir(dir) if d.is_dir()] + # classes.sort() + # class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)} + # return classes, class_to_idx @staticmethod def _get_predicting_files(samples: Union[Sequence, str]) -> List[str]: @@ -216,25 +217,26 @@ def _load_data_dir( dataset.num_classes = len(classes) return classes, make_dataset(data, class_to_idx, IMG_EXTENSIONS, None) - @classmethod - def _load_data_files_labels(cls, data: Any, dataset: Optional[AutoDataset] = None) -> Any: - _classes = [tmp[1] for tmp in data] - - _classes = torch.stack([ - torch.tensor(int(_cls)) if not isinstance(_cls, torch.Tensor) else _cls.view(-1) for _cls in _classes - ]).unique() - - dataset.num_classes = len(_classes) - - return data + # @classmethod + # def _load_data_files_labels(cls, data: Any, dataset: Optional[AutoDataset] = None) -> Any: + # print('called') + # _classes = [tmp[1] for tmp in data] + # + # _classes = torch.stack([ + # torch.tensor(int(_cls)) if not isinstance(_cls, torch.Tensor) else _cls.view(-1) for _cls in _classes + # ]).unique() + # + # dataset.num_classes = len(_classes) + # + # return data def load_data(self, data: Any, dataset: Optional[AutoDataset] = None) -> Iterable: if isinstance(data, (str, pathlib.Path, list)): classes, data = self._load_data_dir(data=data, dataset=dataset) - state = ClassificationState(classes) + state = LabelsState(classes) self.set_state(state) return data - return self._load_data_files_labels(data=data, dataset=dataset) + # return self._load_data_files_labels(data=data, dataset=dataset) @staticmethod def load_sample(sample) -> Union[Image.Image, torch.Tensor, Tuple[Image.Image, torch.Tensor]]: From be0139775ef7ed5f2e3c209fb0c9f3c483e48522 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Thu, 29 Apr 2021 20:54:25 +0100 Subject: [PATCH 02/78] POC Initial commit --- flash/core/model.py | 14 +- flash/data/auto_dataset.py | 140 ++++------- flash/data/callback.py | 3 - flash/data/data_module.py | 220 +++++------------ flash/data/data_pipeline.py | 36 +-- flash/data/data_source.py | 134 ++++++++-- flash/data/process.py | 34 ++- flash/data/transforms.py | 28 +++ flash/vision/classification/data.py | 365 ++++------------------------ flash/vision/data.py | 61 +++++ 10 files changed, 398 insertions(+), 637 deletions(-) create mode 100644 flash/data/transforms.py create mode 100644 flash/vision/data.py diff --git a/flash/core/model.py b/flash/core/model.py index b02b782a53..0150cfd82c 100644 --- a/flash/core/model.py +++ b/flash/core/model.py @@ -30,6 +30,7 @@ from flash.core.schedulers import _SCHEDULERS_REGISTRY from flash.core.utils import get_callable_dict from flash.data.data_pipeline import DataPipeline +from flash.data.data_source import DataSource, DefaultDataSource from flash.data.process import Postprocess, 
Preprocess, Serializer, SerializerMapping @@ -110,7 +111,8 @@ def step(self, batch: Any, batch_idx: int) -> Any: """ The training/validation/test step. Override for custom behavior. """ - x, y = batch + x, y = batch['input'], batch['target'] + # x, y = batch y_hat = self(x) output = {"y_hat": y_hat} losses = {name: l_fn(y_hat, y) for name, l_fn in self.loss_fn.items()} @@ -154,6 +156,7 @@ def test_step(self, batch: Any, batch_idx: int) -> None: def predict( self, x: Any, + data_source: Union[str, DefaultDataSource, DataSource] = DefaultDataSource.FILES, data_pipeline: Optional[DataPipeline] = None, ) -> Any: """ @@ -171,7 +174,13 @@ def predict( data_pipeline = self.build_data_pipeline(data_pipeline) - x = [x for x in data_pipeline._generate_auto_dataset(x, running_stage)] + if str(data_source) == data_source: + data_source = DefaultDataSource(data_source) + + if not isinstance(data_source, DataSource): + data_source = data_pipeline._preprocess_pipeline.data_source_of_type(data_source.as_type())() + + x = [x for x in data_source.generate_dataset(x, running_stage, data_pipeline)] x = data_pipeline.worker_preprocessor(running_stage)(x) # switch to self.device when #7188 merge in Lightning x = self.transfer_batch_to_device(x, next(self.parameters()).device) @@ -181,6 +190,7 @@ def predict( return predictions def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: + batch = batch['input'] if isinstance(batch, tuple): batch = batch[0] elif isinstance(batch, list): diff --git a/flash/data/auto_dataset.py b/flash/data/auto_dataset.py index 2ba6dd92f4..486ce68772 100644 --- a/flash/data/auto_dataset.py +++ b/flash/data/auto_dataset.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from inspect import signature -from typing import Any, Callable, Iterable, Iterator, Optional, TYPE_CHECKING +from typing import Any, Callable, Generic, Iterable, Iterator, Optional, Sequence, TYPE_CHECKING, TypeVar import torch from pytorch_lightning.trainer.states import RunningStage @@ -25,9 +25,12 @@ if TYPE_CHECKING: from flash.data.data_pipeline import DataPipeline + from flash.data.data_source import DataSource +DATA_TYPE = TypeVar('DATA_TYPE') -class BaseAutoDataset: + +class BaseAutoDataset(Generic[DATA_TYPE]): DATASET_KEY = "dataset" """ @@ -38,45 +41,41 @@ class BaseAutoDataset: def __init__( self, - data: Any, - load_data: Optional[Callable] = None, - load_sample: Optional[Callable] = None, + data: DATA_TYPE, + data_source: 'DataSource', + running_stage: RunningStage, data_pipeline: Optional['DataPipeline'] = None, - running_stage: Optional[RunningStage] = None ) -> None: super().__init__() - if load_data or load_sample: - if data_pipeline: - rank_zero_warn( - "``datapipeline`` is specified but load_sample and/or load_data are also specified. 
" - "Won't use datapipeline" - ) - # initial states - self._load_data_called = False - self._running_stage = None - self.data = data + self.data_source = data_source self.data_pipeline = data_pipeline - self.load_data = load_data - self.load_sample = load_sample - # trigger the setup only if `running_stage` is provided + self._running_stage = None self.running_stage = running_stage @property - def running_stage(self) -> Optional[RunningStage]: + def running_stage(self) -> RunningStage: return self._running_stage @running_stage.setter def running_stage(self, running_stage: RunningStage) -> None: - if self._running_stage != running_stage or (not self._running_stage): - self._running_stage = running_stage - self._load_data_context = CurrentRunningStageFuncContext(self._running_stage, "load_data", self.preprocess) - self._load_sample_context = CurrentRunningStageFuncContext( - self._running_stage, "load_sample", self.preprocess + from flash.data.data_source import DataSource # Hack to avoid circular import TODO: something better than this + + self._running_stage = running_stage + + self._load_sample_context = CurrentRunningStageFuncContext(self.running_stage, "load_sample", self.data_source) + + self.load_sample = getattr( + self.data_source, + self.data_pipeline._resolve_function_hierarchy( + 'load_sample', + self.data_source, + self.running_stage, + DataSource, ) - self._setup(running_stage) + ) @property def preprocess(self) -> Optional[Preprocess]: @@ -89,90 +88,33 @@ def control_flow_callback(self) -> Optional[ControlFlow]: if preprocess is not None: return ControlFlow(preprocess.callbacks) - def _call_load_data(self, data: Any) -> Iterable: - parameters = signature(self.load_data).parameters - if len(parameters) > 1 and self.DATASET_KEY in parameters: - return self.load_data(data, self) - else: - return self.load_data(data) - def _call_load_sample(self, sample: Any) -> Any: - parameters = signature(self.load_sample).parameters - if len(parameters) > 1 and self.DATASET_KEY in parameters: - return self.load_sample(sample, self) - else: - return self.load_sample(sample) - - def _setup(self, stage: Optional[RunningStage]) -> None: - assert not stage or _STAGES_PREFIX[stage] in _STAGES_PREFIX_VALUES - previous_load_data = self.load_data.__code__ if self.load_data else None - - if self._running_stage and self.data_pipeline and (not self.load_data or not self.load_sample) and stage: - self.load_data = getattr( - self.preprocess, - self.data_pipeline._resolve_function_hierarchy('load_data', self.preprocess, stage, Preprocess) - ) - self.load_sample = getattr( - self.preprocess, - self.data_pipeline._resolve_function_hierarchy('load_sample', self.preprocess, stage, Preprocess) - ) - if self.load_data and (previous_load_data != self.load_data.__code__ or not self._load_data_called): - if previous_load_data: - rank_zero_warn( - "The load_data function of the Autogenerated Dataset changed. " - "This is not expected! Preloading Data again to ensure compatibility. This may take some time." 
- ) - self.setup() - self._load_data_called = True - - def setup(self): - raise NotImplementedError - + if self.load_sample: + with self._load_sample_context: + parameters = signature(self.load_sample).parameters + if len(parameters) > 1 and self.DATASET_KEY in parameters: + sample = self.load_sample(sample, self) + else: + sample = self.load_sample(sample) + if self.control_flow_callback: + self.control_flow_callback.on_load_sample(sample, self.running_stage) + return sample -class AutoDataset(BaseAutoDataset, Dataset): - def setup(self): - with self._load_data_context: - self.preprocessed_data = self._call_load_data(self.data) +class AutoDataset(BaseAutoDataset[Sequence[Any]], Dataset): def __getitem__(self, index: int) -> Any: - if not self.load_sample and not self.load_data: - raise RuntimeError("`__getitem__` for `load_sample` and `load_data` could not be inferred.") - if self.load_sample: - with self._load_sample_context: - data: Any = self._call_load_sample(self.preprocessed_data[index]) - if self.control_flow_callback: - self.control_flow_callback.on_load_sample(data, self.running_stage) - return data - return self.preprocessed_data[index] + return self._call_load_sample(self.data[index]) def __len__(self) -> int: - if not self.load_sample and not self.load_data: - raise RuntimeError("`__len__` for `load_sample` and `load_data` could not be inferred.") - return len(self.preprocessed_data) + return len(self.data) -class IterableAutoDataset(BaseAutoDataset, IterableDataset): - - def setup(self): - with self._load_data_context: - self.dataset = self._call_load_data(self.data) - self.dataset_iter = None +class IterableAutoDataset(BaseAutoDataset[Iterable[Any]], IterableDataset): def __iter__(self): - self.dataset_iter = iter(self.dataset) + self.data_iter = iter(self.data) return self def __next__(self) -> Any: - if not self.load_sample and not self.load_data: - raise RuntimeError("`__getitem__` for `load_sample` and `load_data` could not be inferred.") - - data = next(self.dataset_iter) - - if self.load_sample: - with self._load_sample_context: - data: Any = self._call_load_sample(data) - if self.control_flow_callback: - self.control_flow_callback.on_load_sample(data, self.running_stage) - return data - return data + return self._call_load_sample(next(self.data_iter)) diff --git a/flash/data/callback.py b/flash/data/callback.py index a479a6e59e..1221046a31 100644 --- a/flash/data/callback.py +++ b/flash/data/callback.py @@ -190,9 +190,6 @@ def enable(self): yield self.enabled = False - def attach_to_datamodule(self, datamodule) -> None: - datamodule.data_fetcher = self - def attach_to_preprocess(self, preprocess: 'flash.data.process.Preprocess') -> None: preprocess.add_callbacks([self]) self._preprocess = preprocess diff --git a/flash/data/data_module.py b/flash/data/data_module.py index bcb3787268..c594fcdecb 100644 --- a/flash/data/data_module.py +++ b/flash/data/data_module.py @@ -12,8 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. 
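The auto_dataset.py changes above strip BaseAutoDataset down to a thin wrapper: the owning DataSource runs load_data once when the dataset is generated, and the dataset only applies load_sample lazily per item (see generate_dataset further down in data_source.py). A rough, dependency-free sketch of that split follows; the Tiny* classes are illustrative only, not the Flash implementations.

from typing import Any, Callable, Mapping, Sequence


class TinyAutoDataset:
    """Map-style dataset: holds records produced by load_data, applies load_sample per item."""

    def __init__(self, data: Sequence[Mapping[str, Any]], load_sample: Callable[[Mapping[str, Any]], Any]) -> None:
        self.data = data
        self.load_sample = load_sample

    def __getitem__(self, index: int) -> Any:
        return self.load_sample(self.data[index])

    def __len__(self) -> int:
        return len(self.data)


class TinyFolderSource:
    """Illustrative data source: load_data runs once up front, load_sample runs lazily."""

    def load_data(self, path_to_label: Mapping[str, int]) -> Sequence[Mapping[str, Any]]:
        return [{"input": path, "target": label} for path, label in path_to_label.items()]

    def load_sample(self, sample: Mapping[str, Any]) -> Mapping[str, Any]:
        # In the real ImageFoldersDataSource this is where the image file would be decoded.
        return {"input": "decoded(" + sample["input"] + ")", "target": sample["target"]}

    def generate_dataset(self, data: Mapping[str, int]) -> TinyAutoDataset:
        return TinyAutoDataset(self.load_data(data), self.load_sample)


dataset = TinyFolderSource().generate_dataset({"cat/1.png": 0, "dog/1.png": 1})
assert len(dataset) == 2 and dataset[1] == {"input": "decoded(dog/1.png)", "target": 1}
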
import os +import pathlib import platform -from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Union +from enum import Enum +from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type, TypeVar, Union import numpy as np import pytorch_lightning as pl @@ -29,6 +31,7 @@ from flash.data.base_viz import BaseVisualization from flash.data.callback import BaseDataFetcher from flash.data.data_pipeline import DataPipeline, DefaultPreprocess, Postprocess, Preprocess +from flash.data.data_source import DataSource, FoldersDataSource from flash.data.splits import SplitDataset from flash.data.utils import _STAGES_PREFIX @@ -53,20 +56,35 @@ class DataModule(pl.LightningDataModule): def __init__( self, - train_dataset: Optional[Dataset] = None, - val_dataset: Optional[Dataset] = None, - test_dataset: Optional[Dataset] = None, - predict_dataset: Optional[Dataset] = None, + data_source: DataSource, + preprocess: Optional[Preprocess] = None, + postprocess: Optional[Postprocess] = None, + data_fetcher: Optional[BaseDataFetcher] = None, + val_split: Optional[float] = None, batch_size: int = 1, num_workers: Optional[int] = 0, ) -> None: super().__init__() + + self._preprocess: Optional[Preprocess] = preprocess + self._postprocess: Optional[Postprocess] = postprocess + self._viz: Optional[BaseVisualization] = None + self._data_fetcher: Optional[BaseDataFetcher] = data_fetcher or self.configure_data_fetcher() + + # TODO: Preprocess can change + self.data_fetcher.attach_to_preprocess(self.preprocess) + + train_dataset, val_dataset, test_dataset, predict_dataset = data_source.to_datasets(self.data_pipeline) + self._train_ds = train_dataset self._val_ds = val_dataset self._test_ds = test_dataset self._predict_ds = predict_dataset + if self._train_ds is not None and (val_split is not None and self._val_ds is None): + self._train_ds, self._val_ds = self._split_train_val(self._train_ds, val_split) + if self._train_ds: self.train_dataloader = self._train_dataloader @@ -89,12 +107,6 @@ def __init__( num_workers = os.cpu_count() self.num_workers = num_workers - self._preprocess: Optional[Preprocess] = None - self._postprocess: Optional[Postprocess] = None - self._viz: Optional[BaseVisualization] = None - self._data_fetcher: Optional[BaseDataFetcher] = None - - # this may also trigger data preloading self.set_running_stages() @property @@ -141,7 +153,7 @@ def data_fetcher(self) -> BaseDataFetcher: def data_fetcher(self, data_fetcher: BaseDataFetcher) -> None: self._data_fetcher = data_fetcher - def _reset_iterator(self, stage: RunningStage) -> Iterable[Any]: + def _reset_iterator(self, stage: str) -> Iterable[Any]: iter_name = f"_{stage}_iter" # num_workers has to be set to 0 to work properly num_workers = self.num_workers @@ -152,7 +164,7 @@ def _reset_iterator(self, stage: RunningStage) -> Iterable[Any]: setattr(self, iter_name, iterator) return iterator - def _show_batch(self, stage: RunningStage, func_names: Union[str, List[str]], reset: bool = True) -> None: + def _show_batch(self, stage: str, func_names: Union[str, List[str]], reset: bool = True) -> None: """ This function is used to handle transforms profiling for batch visualization. 
""" @@ -278,11 +290,6 @@ def _predict_dataloader(self) -> DataLoader: collate_fn=self._resolve_collate_fn(predict_ds, RunningStage.PREDICTING) ) - def generate_auto_dataset(self, *args, **kwargs): - if all(a is None for a in args) and len(kwargs) == 0: - return None - return self.data_pipeline._generate_auto_dataset(*args, **kwargs) - @property def num_classes(self) -> Optional[int]: return ( @@ -303,52 +310,8 @@ def data_pipeline(self) -> DataPipeline: return DataPipeline(self.preprocess, self.postprocess) @staticmethod - def _check_transforms(transform: Dict[str, Union[Module, Callable]]) -> Dict[str, Union[Module, Callable]]: - if not isinstance(transform, dict): - raise MisconfigurationException( - "Transform should be a dict. Here are the available keys " - f"for your transforms: {DataPipeline.PREPROCESS_FUNCS}." - ) - return transform - - @classmethod - def autogenerate_dataset( - cls, - data: Any, - running_stage: RunningStage, - whole_data_load_fn: Optional[Callable] = None, - per_sample_load_fn: Optional[Callable] = None, - data_pipeline: Optional[DataPipeline] = None, - use_iterable_auto_dataset: bool = False, - ) -> BaseAutoDataset: - """ - This function is used to generate an ``BaseAutoDataset`` from a ``DataPipeline`` if provided - or from the provided ``whole_data_load_fn``, ``per_sample_load_fn`` functions directly - """ - - preprocess = getattr(data_pipeline, '_preprocess_pipeline', None) - - if whole_data_load_fn is None: - whole_data_load_fn = getattr( - preprocess, - DataPipeline._resolve_function_hierarchy('load_data', preprocess, running_stage, Preprocess) - ) - - if per_sample_load_fn is None: - per_sample_load_fn = getattr( - preprocess, - DataPipeline._resolve_function_hierarchy('load_sample', preprocess, running_stage, Preprocess) - ) - if use_iterable_auto_dataset: - return IterableAutoDataset( - data, whole_data_load_fn, per_sample_load_fn, data_pipeline, running_stage=running_stage - ) - return BaseAutoDataset(data, whole_data_load_fn, per_sample_load_fn, data_pipeline, running_stage=running_stage) - - @classmethod def _split_train_val( - cls, - train_dataset: Union[AutoDataset, IterableAutoDataset], + train_dataset: Dataset, val_split: float, ) -> Tuple[Any, Any]: @@ -357,7 +320,7 @@ def _split_train_val( if isinstance(train_dataset, IterableAutoDataset): raise MisconfigurationException( - "`val_split` should be `None` when the dataset is built with an IterativeDataset." + "`val_split` should be `None` when the dataset is built with an IterableDataset." 
) train_num_samples = len(train_dataset) @@ -367,113 +330,42 @@ def _split_train_val( return SplitDataset(train_dataset, train_indices), SplitDataset(train_dataset, val_indices) @classmethod - def _generate_dataset_if_possible( - cls, - data: Optional[Any], - running_stage: RunningStage, - whole_data_load_fn: Optional[Callable] = None, - per_sample_load_fn: Optional[Callable] = None, - data_pipeline: Optional[DataPipeline] = None, - use_iterable_auto_dataset: bool = False, - ) -> Optional[BaseAutoDataset]: - if data is None: - return - - if data_pipeline: - return data_pipeline._generate_auto_dataset( - data, - running_stage=running_stage, - use_iterable_auto_dataset=use_iterable_auto_dataset, - ) - - return cls.autogenerate_dataset( - data, - running_stage, - whole_data_load_fn, - per_sample_load_fn, - data_pipeline, - use_iterable_auto_dataset=use_iterable_auto_dataset, - ) - - @classmethod - def from_load_data_inputs( + def from_folders( cls, - train_load_data_input: Optional[Any] = None, - val_load_data_input: Optional[Any] = None, - test_load_data_input: Optional[Any] = None, - predict_load_data_input: Optional[Any] = None, + train_folder: Optional[Union[str, pathlib.Path]] = None, + val_folder: Optional[Union[str, pathlib.Path]] = None, + test_folder: Optional[Union[str, pathlib.Path]] = None, + predict_folder: Union[str, pathlib.Path] = None, + train_transform: Optional[Union[str, Dict]] = 'default', + val_transform: Optional[Union[str, Dict]] = 'default', + test_transform: Optional[Union[str, Dict]] = 'default', + predict_transform: Optional[Union[str, Dict]] = 'default', data_fetcher: BaseDataFetcher = None, preprocess: Optional[Preprocess] = None, - postprocess: Optional[Postprocess] = None, - use_iterable_auto_dataset: bool = False, - seed: int = 42, val_split: Optional[float] = None, - **kwargs, + batch_size: int = 4, + num_workers: Optional[int] = None, ) -> 'DataModule': - """ - This functions is an helper to generate a ``DataModule`` from a ``DataPipeline``. 
- - Args: - cls: ``DataModule`` subclass - train_load_data_input: Data to be received by the ``train_load_data`` function - from this :class:`~flash.data.process.Preprocess` - val_load_data_input: Data to be received by the ``val_load_data`` function - from this :class:`~flash.data.process.Preprocess` - test_load_data_input: Data to be received by the ``test_load_data`` function - from this :class:`~flash.data.process.Preprocess` - predict_load_data_input: Data to be received by the ``predict_load_data`` function - from this :class:`~flash.data.process.Preprocess` - kwargs: Any extra arguments to instantiate the provided ``DataModule`` - """ - # trick to get data_pipeline from empty DataModule - if preprocess or postprocess: - data_pipeline = DataPipeline( - preprocess or cls(**kwargs).preprocess, - postprocess or cls(**kwargs).postprocess, - ) - else: - data_pipeline = cls(**kwargs).data_pipeline - - data_fetcher: BaseDataFetcher = data_fetcher or cls.configure_data_fetcher() - - data_fetcher.attach_to_preprocess(data_pipeline._preprocess_pipeline) - train_dataset = cls._generate_dataset_if_possible( - train_load_data_input, - running_stage=RunningStage.TRAINING, - data_pipeline=data_pipeline, - use_iterable_auto_dataset=use_iterable_auto_dataset, - ) - val_dataset = cls._generate_dataset_if_possible( - val_load_data_input, - running_stage=RunningStage.VALIDATING, - data_pipeline=data_pipeline, - use_iterable_auto_dataset=use_iterable_auto_dataset, - ) - test_dataset = cls._generate_dataset_if_possible( - test_load_data_input, - running_stage=RunningStage.TESTING, - data_pipeline=data_pipeline, - use_iterable_auto_dataset=use_iterable_auto_dataset, - ) - predict_dataset = cls._generate_dataset_if_possible( - predict_load_data_input, - running_stage=RunningStage.PREDICTING, - data_pipeline=data_pipeline, - use_iterable_auto_dataset=use_iterable_auto_dataset, + preprocess = preprocess or cls.preprocess_cls( + train_transform, + val_transform, + test_transform, + predict_transform, ) - if train_dataset is not None and (val_split is not None and val_dataset is None): - train_dataset, val_dataset = cls._split_train_val(train_dataset, val_split) + data_source = preprocess.data_source_of_type(FoldersDataSource)( + train_folder=train_folder, + val_folder=val_folder, + test_folder=test_folder, + predict_folder=predict_folder, + ) - datamodule = cls( - train_dataset=train_dataset, - val_dataset=val_dataset, - test_dataset=test_dataset, - predict_dataset=predict_dataset, - **kwargs + return cls( + data_source, + preprocess, + data_fetcher=data_fetcher, + val_split=val_split, + batch_size=batch_size, + num_workers=num_workers, ) - datamodule._preprocess = data_pipeline._preprocess_pipeline - datamodule._postprocess = data_pipeline._postprocess_pipeline - data_fetcher.attach_to_datamodule(datamodule) - return datamodule diff --git a/flash/data/data_pipeline.py b/flash/data/data_pipeline.py index 5715840abe..89f152346d 100644 --- a/flash/data/data_pipeline.py +++ b/flash/data/data_pipeline.py @@ -506,24 +506,24 @@ def _detach_postprocess_from_model(model: 'Task'): # if any other pipeline is attached which may rely on this! 
model.predict_step = model.predict_step._original - def _generate_callable_auto_dataset( - self, data: Union[Iterable, Any], running_stage: RunningStage = None - ) -> Callable: - - def fn(): - return self._generate_auto_dataset(data, running_stage=running_stage) - - return fn - - def _generate_auto_dataset( - self, - data: Union[Iterable, Any], - running_stage: RunningStage = None, - use_iterable_auto_dataset: bool = False - ) -> Union[AutoDataset, IterableAutoDataset]: - if use_iterable_auto_dataset: - return IterableAutoDataset(data, data_pipeline=self, running_stage=running_stage) - return AutoDataset(data=data, data_pipeline=self, running_stage=running_stage) + # def _generate_callable_auto_dataset( + # self, data: Union[Iterable, Any], running_stage: RunningStage = None + # ) -> Callable: + # + # def fn(): + # return self._generate_auto_dataset(data, running_stage=running_stage) + # + # return fn + + # def _generate_auto_dataset( + # self, + # data: Union[Iterable, Any], + # running_stage: RunningStage = None, + # use_iterable_auto_dataset: bool = False + # ) -> Union[AutoDataset, IterableAutoDataset]: + # if use_iterable_auto_dataset: + # return IterableAutoDataset(data, data_pipeline=self, running_stage=running_stage) + # return AutoDataset(data=data, data_pipeline=self, running_stage=running_stage) def to_dataloader( self, data: Union[Iterable, Any], auto_collate: Optional[bool] = None, **loader_kwargs diff --git a/flash/data/data_source.py b/flash/data/data_source.py index 11dbd19305..29a08a6713 100644 --- a/flash/data/data_source.py +++ b/flash/data/data_source.py @@ -15,13 +15,26 @@ import pathlib from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Any, Dict, Generic, Iterable, List, Mapping, Optional, Sequence, Tuple, Union +from enum import Enum +from typing import Any, Dict, Generic, Iterable, List, Mapping, Optional, Sequence, Sized, Tuple, Type, TypeVar, Union import numpy as np +from pytorch_lightning.trainer.states import RunningStage from torch.nn import Module from torchvision.datasets.folder import has_file_allowed_extension, make_dataset +from flash.data.auto_dataset import AutoDataset, BaseAutoDataset, IterableAutoDataset +from flash.data.data_pipeline import DataPipeline from flash.data.process import ProcessState, Properties +from flash.data.utils import _STAGES_PREFIX, CurrentRunningStageFuncContext + + +def has_len(data: Union[Sequence[Any], Iterable[Any]]) -> bool: + try: + len(data) + return True + except (TypeError, NotImplementedError): + return False @dataclass(unsafe_hash=True, frozen=True) @@ -30,6 +43,18 @@ class LabelsState(ProcessState): labels: Optional[Sequence[str]] +class MockDataset: + + def __init__(self): + self.metadata = {} + + def __setattr__(self, key, value): + if key != 'metadata': + self.metadata[key] = value + else: + object.__setattr__(self, key, value) + + class DataSource(Properties, Module, ABC): def __init__( @@ -46,18 +71,6 @@ def __init__( self.test_data = test_data self.predict_data = predict_data - def train_load_data(self, dataset: Optional[Any] = None) -> Iterable[Mapping[str, Any]]: - return self.load_data(self.train_data, dataset) - - def val_load_data(self, dataset: Optional[Any] = None) -> Iterable[Mapping[str, Any]]: - return self.load_data(self.val_data, dataset) - - def test_load_data(self, dataset: Optional[Any] = None) -> Iterable[Mapping[str, Any]]: - return self.load_data(self.test_data, dataset) - - def predict_load_data(self, dataset: Optional[Any] = None) -> 
Iterable[Mapping[str, Any]]: - return self.load_data(self.predict_data, dataset) - @abstractmethod def load_data( self, @@ -79,6 +92,55 @@ def load_data( def load_sample(self, sample: Mapping[str, Any], dataset: Optional[Any] = None) -> Any: """Loads single sample from dataset""" + def to_datasets(self, data_pipeline: DataPipeline) -> Tuple[Optional[BaseAutoDataset], ...]: + train_dataset = self._generate_dataset_if_possible(RunningStage.TRAINING, data_pipeline) + val_dataset = self._generate_dataset_if_possible(RunningStage.VALIDATING, data_pipeline) + test_dataset = self._generate_dataset_if_possible(RunningStage.TESTING, data_pipeline) + predict_dataset = self._generate_dataset_if_possible(RunningStage.PREDICTING, data_pipeline) + return train_dataset, val_dataset, test_dataset, predict_dataset + + def _generate_dataset_if_possible( + self, + running_stage: RunningStage, + data_pipeline: DataPipeline, + ) -> Optional[Union[AutoDataset, IterableAutoDataset]]: + data = getattr(self, f"{_STAGES_PREFIX[running_stage]}_data", None) + if data is not None: + return self.generate_dataset(data, running_stage, data_pipeline) + + def generate_dataset( + self, + data, + running_stage: RunningStage, + data_pipeline: DataPipeline, + ) -> Optional[Union[AutoDataset, IterableAutoDataset]]: + mock_dataset = MockDataset() + with CurrentRunningStageFuncContext(running_stage, "load_data", self): + data = self.load_data(data, mock_dataset) # TODO: Should actually resolve this + + if has_len(data): + dataset = AutoDataset(data, self, running_stage, data_pipeline) + else: + dataset = IterableAutoDataset(data, self, running_stage, data_pipeline) + dataset.__dict__.update(mock_dataset.metadata) + return dataset + + +T = TypeVar("T") + + +class DefaultDataSource(Enum): # TODO: This could be replaced with a data source registry that the user can add to + + FOLDERS = "folders" + FILES = "files" + + def as_type(self) -> Type[DataSource]: + _data_source_types = { + DefaultDataSource.FOLDERS: FoldersDataSource, + DefaultDataSource.FILES: FilesDataSource, + } + return _data_source_types[self] + class SequenceDataSource(DataSource, ABC): @@ -91,14 +153,13 @@ def __init__( test_inputs: Optional[Sequence[Any]] = None, test_targets: Optional[Sequence[Any]] = None, predict_inputs: Optional[Sequence[Any]] = None, - predict_targets: Optional[Sequence[Any]] = None, labels: Optional[Sequence[str]] = None ): super().__init__( train_data=(train_inputs, train_targets), val_data=(val_inputs, val_targets), test_data=(test_inputs, test_targets), - predict_data=(predict_inputs, predict_targets), + predict_data=(predict_inputs, None), ) self.labels = labels @@ -113,7 +174,7 @@ def load_data(self, data: Any, dataset: Optional[Any] = None) -> Iterable[Mappin return [{'input': input, 'target': target} for input, target in zip(inputs, targets)] -class FolderDataSource(DataSource, ABC): +class FoldersDataSource(DataSource, ABC): def __init__( self, @@ -150,7 +211,48 @@ def find_classes(dir: str) -> Tuple[List[str], Dict[str, int]]: def load_data(self, data: Any, dataset: Optional[Any] = None) -> Iterable[Mapping[str, Any]]: classes, class_to_idx = self.find_classes(data) + if not classes: + files = [os.path.join(data, file) for file in os.listdir(data)] + return [{ + 'input': file + } for file in filter( + lambda file: has_file_allowed_extension(file, self.extensions), + files, + )] self.set_state(LabelsState(classes)) dataset.num_classes = len(classes) data = make_dataset(data, class_to_idx, extensions=self.extensions) return [{'input': 
input, 'target': target} for input, target in data] + + +class FilesDataSource(DataSource, ABC): + + def __init__( + self, + train_files: Optional[Union[Sequence[Union[str, pathlib.Path]], Iterable[Union[str, pathlib.Path]]]] = None, + train_targets: Optional[Union[Sequence[Any], Iterable[Any]]] = None, + val_files: Optional[Union[Sequence[Union[str, pathlib.Path]], Iterable[Union[str, pathlib.Path]]]] = None, + val_targets: Optional[Union[Sequence[Any], Iterable[Any]]] = None, + test_files: Optional[Union[Sequence[Union[str, pathlib.Path]], Iterable[Union[str, pathlib.Path]]]] = None, + test_targets: Optional[Union[Sequence[Any], Iterable[Any]]] = None, + predict_files: Optional[Union[Sequence[Union[str, pathlib.Path]], Iterable[Union[str, pathlib.Path]]]] = None, + extensions: Optional[Tuple[str, ...]] = None, + ): + super().__init__( + train_data=(train_files, train_targets), + val_data=(val_files, val_targets), + test_data=(test_files, test_targets), + predict_data=(predict_files, None), + ) + + self.extensions = extensions + + def load_data(self, data: Any, dataset: Optional[Any] = None) -> Iterable[Mapping[str, Any]]: + # TODO: Bring back the code to work out how many classes there are + if isinstance(data, tuple): + files, targets = data + else: + files, targets = data, None # TODO: Sort this out + if not targets: + return [{'input': input} for input in files] + return [{'input': file, 'target': target} for file, target in zip(files, targets)] diff --git a/flash/data/process.py b/flash/data/process.py index 09a2353e68..4d4368712a 100644 --- a/flash/data/process.py +++ b/flash/data/process.py @@ -33,6 +33,7 @@ if TYPE_CHECKING: from flash.data.data_pipeline import DataPipelineState + from flash.data.data_source import DataSource @dataclass(unsafe_hash=True, frozen=True) @@ -427,25 +428,6 @@ def add_callbacks(self, callbacks: List['FlashCallback']): _callbacks = [c for c in callbacks if c not in self._callbacks] self._callbacks.extend(_callbacks) - # @classmethod - # def load_data(cls, data: Any, dataset: Optional[Any] = None) -> Mapping: - # """Loads entire data from Dataset. The input ``data`` can be anything, but you need to return a Mapping. - # - # Example:: - # - # # data: "." 
- # # output: [("./cat/1.png", 1), ..., ("./dog/10.png", 0)] - # - # output: Mapping = load_data(data) - # - # """ - # return data - # - # @classmethod - # def load_sample(cls, sample: Any, dataset: Optional[Any] = None) -> Any: - # """Loads single sample from dataset""" - # return sample - def pre_tensor_transform(self, sample: Any) -> Any: """Transforms to apply on a single object.""" return sample @@ -498,8 +480,22 @@ def per_batch_transform_on_device(self, batch: Any) -> Any: return batch +T = TypeVar("T") + + class DefaultPreprocess(Preprocess): + data_sources: List[Type['DataSource']] = [] # TODO: Make this a property + + # TODO: Doesn't need to be a classmethod + @classmethod + def data_source_of_type(cls, data_source_type: Type[T]) -> Optional[Type[T]]: + data_sources = cls.data_sources + for data_source in data_sources: + if issubclass(data_source, data_source_type): + return data_source + return None + def get_state_dict(self) -> Dict[str, Any]: return {} diff --git a/flash/data/transforms.py b/flash/data/transforms.py new file mode 100644 index 0000000000..3a0c12bd1b --- /dev/null +++ b/flash/data/transforms.py @@ -0,0 +1,28 @@ +from typing import Any, Mapping, Sequence, Union + +from torch import nn + +from flash.data.utils import convert_to_modules + + +class ApplyToKeys(nn.Sequential): + + def __init__(self, keys: Union[str, Sequence[str]], *args): + super().__init__(*[convert_to_modules(arg) for arg in args]) + if str(keys) == keys: + keys = [keys] + self.keys = keys + + def forward(self, x: Mapping[str, Any]) -> Mapping[str, Any]: + inputs = [x[key] for key in filter(lambda key: key in x, self.keys)] + if len(inputs) > 0: + outputs = super().forward(*inputs) + if not isinstance(outputs, tuple): + outputs = (outputs, ) + + result = {} + result.update(x) + for i, key in enumerate(self.keys): + result[key] = outputs[i] + return result + return x diff --git a/flash/vision/classification/data.py b/flash/vision/classification/data.py index 7d303ea522..0e6182af87 100644 --- a/flash/vision/classification/data.py +++ b/flash/vision/classification/data.py @@ -33,8 +33,10 @@ from flash.data.callback import BaseDataFetcher from flash.data.data_module import DataModule from flash.data.data_source import LabelsState -from flash.data.process import Preprocess +from flash.data.process import DefaultPreprocess +from flash.data.transforms import ApplyToKeys from flash.utils.imports import _KORNIA_AVAILABLE, _MATPLOTLIB_AVAILABLE +from flash.vision.data import ImageFilesDataSource, ImageFoldersDataSource if _KORNIA_AVAILABLE: import kornia as K @@ -45,8 +47,9 @@ plt = None -class ImageClassificationPreprocess(Preprocess): +class ImageClassificationPreprocess(DefaultPreprocess): + data_sources = [ImageFoldersDataSource, ImageFilesDataSource] to_tensor = T.ToTensor() def __init__( @@ -76,79 +79,67 @@ def get_state_dict(self) -> Dict[str, Any]: def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool): return cls(**state_dict) - # @staticmethod - # def _find_classes(dir: str) -> Tuple: - # """ - # Finds the class folders in a dataset. - # Args: - # dir: Root directory path. - # Returns: - # tuple: (classes, class_to_idx) where classes are relative to (dir), and class_to_idx is a dictionary. - # Ensures: - # No class is a subdirectory of another. 
- # """ - # classes = [d.name for d in os.scandir(dir) if d.is_dir()] - # classes.sort() - # class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)} - # return classes, class_to_idx - - @staticmethod - def _get_predicting_files(samples: Union[Sequence, str]) -> List[str]: - files = [] - if isinstance(samples, str): - samples = [samples] - - if isinstance(samples, (list, tuple)) and all(os.path.isdir(s) for s in samples): - files = [os.path.join(sp, f) for sp in samples for f in os.listdir(sp)] - - elif isinstance(samples, (list, tuple)) and all(os.path.isfile(s) for s in samples): - files = samples - - files = list(filter(lambda p: has_file_allowed_extension(p, IMG_EXTENSIONS), files)) - - return files - def default_train_transforms(self, image_size: Tuple[int, int]) -> Dict[str, Callable]: if _KORNIA_AVAILABLE and not os.getenv("FLASH_TESTING", "0") == "1": # Better approach as all transforms are applied on tensor directly return { - "to_tensor_transform": torchvision.transforms.ToTensor(), - "post_tensor_transform": nn.Sequential( + "to_tensor_transform": nn.Sequential( + ApplyToKeys('input', torchvision.transforms.ToTensor()), + ApplyToKeys('target', torch.as_tensor), + ), + "post_tensor_transform": ApplyToKeys( + 'input', # TODO (Edgar): replace with resize once kornia is fixed K.augmentation.RandomResizedCrop(image_size, scale=(1.0, 1.0), ratio=(1.0, 1.0)), K.augmentation.RandomHorizontalFlip(), ), - "per_batch_transform_on_device": nn.Sequential( + "per_batch_transform_on_device": ApplyToKeys( + 'input', K.augmentation.Normalize(torch.tensor([0.485, 0.456, 0.406]), torch.tensor([0.229, 0.224, 0.225])), ) } else: - from torchvision import transforms as T # noqa F811 return { - "pre_tensor_transform": nn.Sequential(T.Resize(image_size), T.RandomHorizontalFlip()), - "to_tensor_transform": torchvision.transforms.ToTensor(), - "post_tensor_transform": T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), + "pre_tensor_transform": ApplyToKeys('input', T.Resize(image_size), T.RandomHorizontalFlip()), + "to_tensor_transform": nn.Sequential( + ApplyToKeys('input', torchvision.transforms.ToTensor()), + ApplyToKeys('target', torch.as_tensor), + ), + "post_tensor_transform": ApplyToKeys( + 'input', + T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), + ), } def default_val_transforms(self, image_size: Tuple[int, int]) -> Dict[str, Callable]: if _KORNIA_AVAILABLE and not os.getenv("FLASH_TESTING", "0") == "1": # Better approach as all transforms are applied on tensor directly return { - "to_tensor_transform": torchvision.transforms.ToTensor(), - "post_tensor_transform": nn.Sequential( + "to_tensor_transform": nn.Sequential( + ApplyToKeys('input', torchvision.transforms.ToTensor()), + ApplyToKeys('target', torch.as_tensor), + ), + "post_tensor_transform": ApplyToKeys( + 'input', # TODO (Edgar): replace with resize once kornia is fixed K.augmentation.RandomResizedCrop(image_size, scale=(1.0, 1.0), ratio=(1.0, 1.0)), ), - "per_batch_transform_on_device": nn.Sequential( + "per_batch_transform_on_device": ApplyToKeys( + 'input', K.augmentation.Normalize(torch.tensor([0.485, 0.456, 0.406]), torch.tensor([0.229, 0.224, 0.225])), ) } else: - from torchvision import transforms as T # noqa F811 return { - "pre_tensor_transform": T.Compose([T.Resize(image_size)]), - "to_tensor_transform": torchvision.transforms.ToTensor(), - "post_tensor_transform": T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), + "pre_tensor_transform": ApplyToKeys('input', T.Resize(image_size)), + 
"to_tensor_transform": nn.Sequential( + ApplyToKeys('input', torchvision.transforms.ToTensor()), + ApplyToKeys('target', torch.as_tensor), + ), + "post_tensor_transform": ApplyToKeys( + 'input', + T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), + ), } def _resolve_transforms( @@ -179,112 +170,18 @@ def _resolve_transforms( predict_transform, ) - @classmethod - def _load_data_dir( - cls, - data: Any, - dataset: Optional[AutoDataset] = None, - ) -> Tuple[Optional[List[str]], List[Tuple[str, int]]]: - if isinstance(data, list): - # TODO: define num_classes elsewhere. This is a bad assumption since the list of - # labels might not contain the complete set of ids so that you can infer the total - # number of classes to train in your dataset. - dataset.num_classes = len(data) - out: List[Tuple[str, int]] = [] - for p, label in data: - if os.path.isdir(p): - # TODO: there is an issue here when a path is provided along with labels. - # os.listdir cannot assure the same file order as the passed labels list. - files_list: List[str] = os.listdir(p) - if len(files_list) > 1: - raise ValueError( - f"The provided directory contains more than one file." - f"Directory: {p} -> Contains: {files_list}" - ) - for f in files_list: - if has_file_allowed_extension(f, IMG_EXTENSIONS): - out.append([os.path.join(p, f), label]) - elif os.path.isfile(p) and has_file_allowed_extension(str(p), IMG_EXTENSIONS): - out.append([p, label]) - else: - raise TypeError(f"Unexpected file path type: {p}.") - return None, out - else: - classes, class_to_idx = cls._find_classes(data) - # TODO: define num_classes elsewhere. This is a bad assumption since the list of - # labels might not contain the complete set of ids so that you can infer the total - # number of classes to train in your dataset. - dataset.num_classes = len(classes) - return classes, make_dataset(data, class_to_idx, IMG_EXTENSIONS, None) - - # @classmethod - # def _load_data_files_labels(cls, data: Any, dataset: Optional[AutoDataset] = None) -> Any: - # print('called') - # _classes = [tmp[1] for tmp in data] - # - # _classes = torch.stack([ - # torch.tensor(int(_cls)) if not isinstance(_cls, torch.Tensor) else _cls.view(-1) for _cls in _classes - # ]).unique() - # - # dataset.num_classes = len(_classes) - # - # return data - - def load_data(self, data: Any, dataset: Optional[AutoDataset] = None) -> Iterable: - if isinstance(data, (str, pathlib.Path, list)): - classes, data = self._load_data_dir(data=data, dataset=dataset) - state = LabelsState(classes) - self.set_state(state) - return data - # return self._load_data_files_labels(data=data, dataset=dataset) - - @staticmethod - def load_sample(sample) -> Union[Image.Image, torch.Tensor, Tuple[Image.Image, torch.Tensor]]: - # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) - if isinstance(sample, torch.Tensor): - out: torch.Tensor = sample - return out - - path: str = "" - if isinstance(sample, (tuple, list)): - path = sample[0] - sample = list(sample) - else: - path = sample - - with open(path, "rb") as f, Image.open(f) as img: - img_out: Image.Image = img.convert("RGB") - - if isinstance(sample, list): - # return a tuple with the PIL image and tensor with the labels. - # returning the tensor helps later to easily collate the batch - # for single/multi label at the same time. 
- out: Tuple[Image.Image, torch.Tensor] = (img_out, torch.as_tensor(sample[1])) - return out - - return img_out - - @classmethod - def predict_load_data(cls, samples: Any) -> Iterable: - if isinstance(samples, torch.Tensor): - return samples - return cls._get_predicting_files(samples) - - def collate(self, samples: Sequence) -> Any: - _samples = [] + def collate(self, samples: Sequence[Dict[str, Any]]) -> Any: # todo: Kornia transforms add batch dimension which need to be removed for sample in samples: - if isinstance(sample, tuple): - sample = (sample[0].squeeze(0), ) + sample[1:] - else: - sample = sample.squeeze(0) - _samples.append(sample) - return default_collate(_samples) + for key in sample.keys(): + if torch.is_tensor(sample[key]): + sample[key] = sample[key].squeeze(0) + return default_collate(samples) def common_step(self, sample: Any) -> Any: - if isinstance(sample, (list, tuple)): - source, target = sample - return self.current_transform(source), target + # if isinstance(sample, (list, tuple)): + # source, target = sample + # return self.current_transform(source), target return self.current_transform(sample) def pre_tensor_transform(self, sample: Any) -> Any: @@ -317,6 +214,8 @@ def per_batch_transform_on_device(self, sample: Any) -> Any: class ImageClassificationData(DataModule): """Data module for image classification tasks.""" + preprocess_cls = ImageClassificationPreprocess + def set_block_viz_window(self, value: bool) -> None: """Setter method to switch on/off matplotlib to pop up windows.""" self.data_fetcher.block_viz_window = value @@ -332,172 +231,6 @@ def _get_num_classes(self, dataset: torch.utils.data.Dataset): return num_classes - @classmethod - def from_folders( - cls, - train_folder: Optional[Union[str, pathlib.Path]] = None, - val_folder: Optional[Union[str, pathlib.Path]] = None, - test_folder: Optional[Union[str, pathlib.Path]] = None, - predict_folder: Union[str, pathlib.Path] = None, - train_transform: Optional[Union[str, Dict]] = 'default', - val_transform: Optional[Union[str, Dict]] = 'default', - test_transform: Optional[Union[str, Dict]] = 'default', - predict_transform: Optional[Union[str, Dict]] = 'default', - batch_size: int = 4, - num_workers: Optional[int] = None, - data_fetcher: BaseDataFetcher = None, - preprocess: Optional[Preprocess] = None, - **kwargs, - ) -> 'DataModule': - """ - Creates a ImageClassificationData object from folders of images arranged in this way: :: - - train/dog/xxx.png - train/dog/xxy.png - train/dog/xxz.png - train/cat/123.png - train/cat/nsdf3.png - train/cat/asd932.png - - Args: - train_folder: Path to training folder. Default: None. - val_folder: Path to validation folder. Default: None. - test_folder: Path to test folder. Default: None. - predict_folder: Path to predict folder. Default: None. - val_transform: Image transform to use for validation and test set. - train_transform: Image transform to use for training set. - val_transform: Image transform to use for validation set. - test_transform: Image transform to use for test set. - predict_transform: Image transform to use for predict set. - batch_size: Batch size for data loading. - num_workers: The number of workers to use for parallelized loading. - Defaults to ``None`` which equals the number of available CPU threads. 
- - Returns: - ImageClassificationData: the constructed data module - - Examples: - >>> img_data = ImageClassificationData.from_folders("train/") # doctest: +SKIP - - """ - preprocess = preprocess or ImageClassificationPreprocess( - train_transform, - val_transform, - test_transform, - predict_transform, - ) - - return cls.from_load_data_inputs( - train_load_data_input=train_folder, - val_load_data_input=val_folder, - test_load_data_input=test_folder, - predict_load_data_input=predict_folder, - batch_size=batch_size, - num_workers=num_workers, - data_fetcher=data_fetcher, - preprocess=preprocess, - **kwargs, - ) - - @classmethod - def from_filepaths( - cls, - train_filepaths: Optional[Union[str, pathlib.Path, Sequence[Union[str, pathlib.Path]]]] = None, - train_labels: Optional[Sequence] = None, - val_filepaths: Optional[Union[str, pathlib.Path, Sequence[Union[str, pathlib.Path]]]] = None, - val_labels: Optional[Sequence] = None, - test_filepaths: Optional[Union[str, pathlib.Path, Sequence[Union[str, pathlib.Path]]]] = None, - test_labels: Optional[Sequence] = None, - predict_filepaths: Optional[Union[str, pathlib.Path, Sequence[Union[str, pathlib.Path]]]] = None, - train_transform: Union[str, Dict] = 'default', - val_transform: Union[str, Dict] = 'default', - test_transform: Union[str, Dict] = 'default', - predict_transform: Union[str, Dict] = 'default', - image_size: Tuple[int, int] = (196, 196), - batch_size: int = 64, - num_workers: Optional[int] = None, - seed: Optional[int] = 42, - data_fetcher: BaseDataFetcher = None, - preprocess: Optional[Preprocess] = None, - val_split: Optional[float] = None, - **kwargs, - ) -> 'ImageClassificationData': - """ - Creates a ImageClassificationData object from folders of images arranged in this way: :: - - folder/dog_xxx.png - folder/dog_xxy.png - folder/dog_xxz.png - folder/cat_123.png - folder/cat_nsdf3.png - folder/cat_asd932_.png - - Args: - - train_filepaths: String or sequence of file paths for training dataset. Defaults to ``None``. - train_labels: Sequence of labels for training dataset. Defaults to ``None``. - val_filepaths: String or sequence of file paths for validation dataset. Defaults to ``None``. - val_labels: Sequence of labels for validation dataset. Defaults to ``None``. - test_filepaths: String or sequence of file paths for test dataset. Defaults to ``None``. - test_labels: Sequence of labels for test dataset. Defaults to ``None``. - train_transform: Image transform to use for the train set. Defaults to ``default``, which loads imagenet - transforms. - val_transform: Image transform to use for the validation set. Defaults to ``default``, which loads - imagenet transforms. - test_transform: Image transform to use for the test set. Defaults to ``default``, which loads imagenet - transforms. - predict_transform: Image transform to use for the predict set. Defaults to ``default``, which loads imagenet - transforms. - batch_size: The batchsize to use for parallel loading. Defaults to ``64``. - num_workers: The number of workers to use for parallelized loading. - Defaults to ``None`` which equals the number of available CPU threads. - seed: Used for the train/val splits. - - Returns: - - ImageClassificationData: The constructed data module. 
- """ - # enable passing in a string which loads all files in that folder as a list - if isinstance(train_filepaths, str): - if os.path.isdir(train_filepaths): - train_filepaths = [os.path.join(train_filepaths, x) for x in os.listdir(train_filepaths)] - else: - train_filepaths = [train_filepaths] - - if isinstance(val_filepaths, str): - if os.path.isdir(val_filepaths): - val_filepaths = [os.path.join(val_filepaths, x) for x in os.listdir(val_filepaths)] - else: - val_filepaths = [val_filepaths] - - if isinstance(test_filepaths, str): - if os.path.isdir(test_filepaths): - test_filepaths = [os.path.join(test_filepaths, x) for x in os.listdir(test_filepaths)] - else: - test_filepaths = [test_filepaths] - - preprocess = preprocess or ImageClassificationPreprocess( - train_transform, - val_transform, - test_transform, - predict_transform, - image_size=image_size, - ) - - return cls.from_load_data_inputs( - train_load_data_input=list(zip(train_filepaths, train_labels)) if train_filepaths else None, - val_load_data_input=list(zip(val_filepaths, val_labels)) if val_filepaths else None, - test_load_data_input=list(zip(test_filepaths, test_labels)) if test_filepaths else None, - predict_load_data_input=predict_filepaths, - batch_size=batch_size, - num_workers=num_workers, - data_fetcher=data_fetcher, - preprocess=preprocess, - seed=seed, - val_split=val_split, - **kwargs - ) - class MatplotlibVisualization(BaseVisualization): """Process and show the image batch and its associated label using matplotlib. diff --git a/flash/vision/data.py b/flash/vision/data.py new file mode 100644 index 0000000000..f41c05c6c0 --- /dev/null +++ b/flash/vision/data.py @@ -0,0 +1,61 @@ +import pathlib +from typing import Any, Iterable, Mapping, Optional, Sequence, Union + +from PIL import Image +from torchvision.datasets.folder import default_loader, IMG_EXTENSIONS + +from flash.data.data_source import FilesDataSource, FoldersDataSource, SequenceDataSource + + +class ImageFoldersDataSource(FoldersDataSource): + + def __init__( + self, + train_folder: Optional[Union[str, pathlib.Path, list]] = None, + val_folder: Optional[Union[str, pathlib.Path, list]] = None, + test_folder: Optional[Union[str, pathlib.Path, list]] = None, + predict_folder: Optional[Union[str, pathlib.Path, list]] = None, + ): + super().__init__( + train_folder=train_folder, + val_folder=val_folder, + test_folder=test_folder, + predict_folder=predict_folder, + extensions=IMG_EXTENSIONS, + ) + + def load_sample(self, sample: Mapping[str, Any], dataset: Optional[Any] = None) -> Any: + result = {} # TODO: this is required to avoid a memory leak, can we automate this? 
+ result.update(sample) + result['input'] = default_loader(sample['input']) + return result + + +class ImageFilesDataSource(FilesDataSource): + + def __init__( + self, + train_files: Optional[Union[Sequence[Union[str, pathlib.Path]], Iterable[Union[str, pathlib.Path]]]] = None, + train_targets: Optional[Union[Sequence[Any], Iterable[Any]]] = None, + val_files: Optional[Union[Sequence[Union[str, pathlib.Path]], Iterable[Union[str, pathlib.Path]]]] = None, + val_targets: Optional[Union[Sequence[Any], Iterable[Any]]] = None, + test_files: Optional[Union[Sequence[Union[str, pathlib.Path]], Iterable[Union[str, pathlib.Path]]]] = None, + test_targets: Optional[Union[Sequence[Any], Iterable[Any]]] = None, + predict_files: Optional[Union[Sequence[Union[str, pathlib.Path]], Iterable[Union[str, pathlib.Path]]]] = None, + ): + super().__init__( # TODO: This feels like it can be simplified + train_files=train_files, + train_targets=train_targets, + val_files=val_files, + val_targets=val_targets, + test_files=test_files, + test_targets=test_targets, + predict_files=predict_files, + extensions=IMG_EXTENSIONS + ) + + def load_sample(self, sample: Mapping[str, Any], dataset: Optional[Any] = None) -> Any: + result = {} # TODO: this is required to avoid a memory leak, can we automate this? + result.update(sample) + result['input'] = default_loader(sample['input']) + return result From 214df853b7ae9737def10404d81d2450a2b4683e Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Thu, 29 Apr 2021 21:02:56 +0100 Subject: [PATCH 03/78] Remove unused code --- flash/vision/classification/data.py | 15 +-------------- 1 file changed, 1 insertion(+), 14 deletions(-) diff --git a/flash/vision/classification/data.py b/flash/vision/classification/data.py index 0e6182af87..1b5b6d0dcd 100644 --- a/flash/vision/classification/data.py +++ b/flash/vision/classification/data.py @@ -12,8 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
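Both image data sources above hand samples around as plain dicts: ``load_data`` returns one record per sample keyed by ``'input'`` (and optionally ``'target'``), and ``load_sample`` copies the record and replaces the path under ``'input'`` with the loaded image. A rough, self-contained sketch of that contract, using toy names and placeholder paths rather than the real classes::

    from typing import Any, Dict, List, Optional

    from PIL import Image


    def toy_load_data(files: List[str], targets: Optional[List[int]] = None) -> List[Dict[str, Any]]:
        # Mirror of the files-style load_data: one dict per sample.
        if targets is None:
            return [{"input": f} for f in files]
        return [{"input": f, "target": t} for f, t in zip(files, targets)]


    def toy_load_sample(sample: Dict[str, Any]) -> Dict[str, Any]:
        # Mirror of load_sample: copy the dict, then swap the path for a PIL image.
        result = {}
        result.update(sample)
        result["input"] = Image.open(result["input"]).convert("RGB")
        return result


    records = toy_load_data(["ant_1.jpg", "bee_1.jpg"], [0, 1])
    # toy_load_sample(records[0]) would return {'input': <PIL.Image>, 'target': 0}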
import os -import pathlib -from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union import numpy as np import torch @@ -22,17 +21,12 @@ from pytorch_lightning.trainer.states import RunningStage from pytorch_lightning.utilities.exceptions import MisconfigurationException from torch import nn -from torch.utils.data import Dataset from torch.utils.data._utils.collate import default_collate from torchvision import transforms as T -from torchvision.datasets.folder import has_file_allowed_extension, IMG_EXTENSIONS, make_dataset -# from flash.core.classification import ClassificationState -from flash.data.auto_dataset import AutoDataset from flash.data.base_viz import BaseVisualization # for viz from flash.data.callback import BaseDataFetcher from flash.data.data_module import DataModule -from flash.data.data_source import LabelsState from flash.data.process import DefaultPreprocess from flash.data.transforms import ApplyToKeys from flash.utils.imports import _KORNIA_AVAILABLE, _MATPLOTLIB_AVAILABLE @@ -224,13 +218,6 @@ def set_block_viz_window(self, value: bool) -> None: def configure_data_fetcher(*args, **kwargs) -> BaseDataFetcher: return MatplotlibVisualization(*args, **kwargs) - def _get_num_classes(self, dataset: torch.utils.data.Dataset): - num_classes = self.get_dataset_attribute(dataset, "num_classes", None) - if num_classes is None: - num_classes = torch.tensor([dataset[idx][1] for idx in range(len(dataset))]).unique().numel() - - return num_classes - class MatplotlibVisualization(BaseVisualization): """Process and show the image batch and its associated label using matplotlib. From 8f93bfb54bfa6dc9890cacf2197e57cec4097ca3 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Fri, 30 Apr 2021 08:27:50 +0100 Subject: [PATCH 04/78] Some fixes --- flash/core/classification.py | 3 ++- flash/data/data_source.py | 25 +++++++++++++++---------- 2 files changed, 17 insertions(+), 11 deletions(-) diff --git a/flash/core/classification.py b/flash/core/classification.py index fbbc13bc8e..a716ae21e2 100644 --- a/flash/core/classification.py +++ b/flash/core/classification.py @@ -135,7 +135,8 @@ class Labels(Classes): def __init__(self, labels: Optional[List[str]] = None, multi_label: bool = False, threshold: float = 0.5): super().__init__(multi_label=multi_label, threshold=threshold) self._labels = labels - self.set_state(LabelsState(labels)) + if labels is not None: + self.set_state(LabelsState(labels)) def serialize(self, sample: Any) -> Union[int, List[int], str, List[str]]: labels = None diff --git a/flash/data/data_source.py b/flash/data/data_source.py index 29a08a6713..314401637d 100644 --- a/flash/data/data_source.py +++ b/flash/data/data_source.py @@ -16,16 +16,14 @@ from abc import ABC, abstractmethod from dataclasses import dataclass from enum import Enum -from typing import Any, Dict, Generic, Iterable, List, Mapping, Optional, Sequence, Sized, Tuple, Type, TypeVar, Union +from typing import Any, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple, Type, Union -import numpy as np from pytorch_lightning.trainer.states import RunningStage -from torch.nn import Module from torchvision.datasets.folder import has_file_allowed_extension, make_dataset from flash.data.auto_dataset import AutoDataset, BaseAutoDataset, IterableAutoDataset from flash.data.data_pipeline import DataPipeline -from flash.data.process import ProcessState, Properties +from flash.data.process import Preprocess, 
ProcessState from flash.data.utils import _STAGES_PREFIX, CurrentRunningStageFuncContext @@ -55,7 +53,7 @@ def __setattr__(self, key, value): object.__setattr__(self, key, value) -class DataSource(Properties, Module, ABC): +class DataSource(ABC): def __init__( self, @@ -71,6 +69,12 @@ def __init__( self.test_data = test_data self.predict_data = predict_data + self._preprocess: Optional[Preprocess] = None + + @property + def preprocess(self) -> Optional[Preprocess]: + return self._preprocess + @abstractmethod def load_data( self, @@ -93,6 +97,9 @@ def load_sample(self, sample: Mapping[str, Any], dataset: Optional[Any] = None) """Loads single sample from dataset""" def to_datasets(self, data_pipeline: DataPipeline) -> Tuple[Optional[BaseAutoDataset], ...]: + # attach preprocess + self._preprocess = data_pipeline._preprocess_pipeline + train_dataset = self._generate_dataset_if_possible(RunningStage.TRAINING, data_pipeline) val_dataset = self._generate_dataset_if_possible(RunningStage.VALIDATING, data_pipeline) test_dataset = self._generate_dataset_if_possible(RunningStage.TESTING, data_pipeline) @@ -126,9 +133,6 @@ def generate_dataset( return dataset -T = TypeVar("T") - - class DefaultDataSource(Enum): # TODO: This could be replaced with a data source registry that the user can add to FOLDERS = "folders" @@ -165,7 +169,7 @@ def __init__( self.labels = labels if self.labels is not None: - self.set_state(LabelsState(self.labels)) + self.preprocess.set_state(LabelsState(self.labels)) def load_data(self, data: Any, dataset: Optional[Any] = None) -> Iterable[Mapping[str, Any]]: inputs, targets = data @@ -219,7 +223,8 @@ def load_data(self, data: Any, dataset: Optional[Any] = None) -> Iterable[Mappin lambda file: has_file_allowed_extension(file, self.extensions), files, )] - self.set_state(LabelsState(classes)) + else: + self.preprocess.set_state(LabelsState(classes)) dataset.num_classes = len(classes) data = make_dataset(data, class_to_idx, extensions=self.extensions) return [{'input': input, 'target': target} for input, target in data] From e8ee4c095c32cbfd9522333226fc9f577ac69725 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Fri, 30 Apr 2021 18:12:44 +0100 Subject: [PATCH 05/78] Simplify data source --- flash/core/model.py | 40 +++++++----- flash/data/auto_dataset.py | 26 ++------ flash/data/batch.py | 2 + flash/data/data_module.py | 9 ++- flash/data/data_pipeline.py | 34 ++++------ flash/data/data_source.py | 63 +++++++++---------- flash/vision/classification/data.py | 22 +++---- .../finetuning/image_classification.py | 6 +- 8 files changed, 92 insertions(+), 110 deletions(-) diff --git a/flash/core/model.py b/flash/core/model.py index 0150cfd82c..4d929f5cd1 100644 --- a/flash/core/model.py +++ b/flash/core/model.py @@ -20,7 +20,6 @@ from pytorch_lightning import LightningModule from pytorch_lightning.callbacks import Callback from pytorch_lightning.trainer.states import RunningStage -from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException from torch import nn from torch.optim.lr_scheduler import _LRScheduler @@ -29,7 +28,7 @@ from flash.core.registry import FlashRegistry from flash.core.schedulers import _SCHEDULERS_REGISTRY from flash.core.utils import get_callable_dict -from flash.data.data_pipeline import DataPipeline +from flash.data.data_pipeline import DataPipeline, DataPipelineState from flash.data.data_source import DataSource, DefaultDataSource from flash.data.process import Postprocess, Preprocess, 
Serializer, SerializerMapping @@ -104,6 +103,8 @@ def __init__( self._postprocess: Optional[Postprocess] = postprocess self._serializer: Optional[Serializer] = None + self._data_pipeline_state: Optional[DataPipelineState] = None + # Explicitly set the serializer to call the setter self.serializer = serializer @@ -172,15 +173,9 @@ def predict( """ running_stage = RunningStage.PREDICTING - data_pipeline = self.build_data_pipeline(data_pipeline) - - if str(data_source) == data_source: - data_source = DefaultDataSource(data_source) - - if not isinstance(data_source, DataSource): - data_source = data_pipeline._preprocess_pipeline.data_source_of_type(data_source.as_type())() + data_pipeline = self.build_data_pipeline(data_source, data_pipeline) - x = [x for x in data_source.generate_dataset(x, running_stage, data_pipeline)] + x = [x for x in data_pipeline._data_source.generate_dataset(x, running_stage)] x = data_pipeline.worker_preprocessor(running_stage)(x) # switch to self.device when #7188 merge in Lightning x = self.transfer_batch_to_device(x, next(self.parameters()).device) @@ -262,7 +257,11 @@ def serializer(self, serializer: Union[Serializer, Mapping[str, Serializer]]): serializer = SerializerMapping(serializer) self._serializer = serializer - def build_data_pipeline(self, data_pipeline: Optional[DataPipeline] = None) -> Optional[DataPipeline]: + def build_data_pipeline( + self, + data_source: Optional[Union[str, DefaultDataSource, DataSource]] = None, + data_pipeline: Optional[DataPipeline] = None, + ) -> Optional[DataPipeline]: """Build a :class:`.DataPipeline` incorporating available :class:`~flash.data.process.Preprocess` and :class:`~flash.data.process.Postprocess` objects. These will be overridden in the following resolution order (lowest priority first): @@ -279,10 +278,11 @@ def build_data_pipeline(self, data_pipeline: Optional[DataPipeline] = None) -> O Returns: The fully resolved :class:`.DataPipeline`. 
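A recurring pattern in these changes is that label names are no longer passed around explicitly: whichever component discovers them (a data source scanning folders, or a user-supplied ``Labels`` serializer) publishes a ``LabelsState``, and any other component attached to the same pipeline state can read it back. A stripped-down illustration of that mechanism, using toy classes instead of the real ``ProcessState``/``DataPipelineState``::

    from dataclasses import dataclass
    from typing import Dict, Optional, Sequence, Type


    @dataclass(frozen=True)
    class ToyState:
        pass


    @dataclass(frozen=True)
    class ToyLabelsState(ToyState):
        labels: Optional[Sequence[str]]


    class ToyPipelineState:
        """Shared bag of states, keyed by their type."""

        def __init__(self) -> None:
            self._states: Dict[Type[ToyState], ToyState] = {}

        def set_state(self, state: ToyState) -> None:
            self._states[type(state)] = state

        def get_state(self, state_type: Type[ToyState]) -> Optional[ToyState]:
            return self._states.get(state_type)


    shared = ToyPipelineState()

    # The data source publishes the class names it found on disk ...
    shared.set_state(ToyLabelsState(["ants", "bees"]))

    # ... and a serializer constructed without explicit labels resolves them later.
    state = shared.get_state(ToyLabelsState)
    print(state.labels)  # ['ants', 'bees']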
""" - preprocess, postprocess, serializer = None, None, None + old_data_source, preprocess, postprocess, serializer = None, None, None, None # Datamodule if self.datamodule is not None and getattr(self.datamodule, 'data_pipeline', None) is not None: + old_data_source = getattr(self.datamodule.data_pipeline, '_data_source', None) preprocess = getattr(self.datamodule.data_pipeline, '_preprocess_pipeline', None) postprocess = getattr(self.datamodule.data_pipeline, '_postprocess_pipeline', None) serializer = getattr(self.datamodule.data_pipeline, '_serializer', None) @@ -290,6 +290,7 @@ def build_data_pipeline(self, data_pipeline: Optional[DataPipeline] = None) -> O elif self.trainer is not None and hasattr( self.trainer, 'datamodule' ) and getattr(self.trainer.datamodule, 'data_pipeline', None) is not None: + old_data_source = getattr(self.trainer.datamodule.data_pipeline, '_data_source', None) preprocess = getattr(self.trainer.datamodule.data_pipeline, '_preprocess_pipeline', None) postprocess = getattr(self.trainer.datamodule.data_pipeline, '_postprocess_pipeline', None) serializer = getattr(self.trainer.datamodule.data_pipeline, '_serializer', None) @@ -315,8 +316,19 @@ def build_data_pipeline(self, data_pipeline: Optional[DataPipeline] = None) -> O getattr(data_pipeline, '_serializer', None), ) - data_pipeline = DataPipeline(preprocess, postprocess, serializer) - data_pipeline.initialize() + data_source = data_source or old_data_source + + if str(data_source) == data_source: + data_source = DefaultDataSource(data_source) + + if not isinstance(data_source, DataSource): + data_source = preprocess.data_source_of_type(data_source.as_type())() + + if old_data_source is not None: + data_source._state.update(old_data_source._state) # TODO: This is a hack + + data_pipeline = DataPipeline(data_source, preprocess, postprocess, serializer) + self._data_pipeline_state = data_pipeline.initialize(self._data_pipeline_state) return data_pipeline @property diff --git a/flash/data/auto_dataset.py b/flash/data/auto_dataset.py index 486ce68772..9f2cf8775c 100644 --- a/flash/data/auto_dataset.py +++ b/flash/data/auto_dataset.py @@ -12,16 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. 
from inspect import signature -from typing import Any, Callable, Generic, Iterable, Iterator, Optional, Sequence, TYPE_CHECKING, TypeVar +from typing import Any, Generic, Iterable, Sequence, TYPE_CHECKING, TypeVar -import torch from pytorch_lightning.trainer.states import RunningStage -from pytorch_lightning.utilities.warning_utils import rank_zero_warn from torch.utils.data import Dataset, IterableDataset -from flash.data.callback import ControlFlow -from flash.data.process import Preprocess -from flash.data.utils import _STAGES_PREFIX, _STAGES_PREFIX_VALUES, CurrentRunningStageFuncContext +from flash.data.utils import CurrentRunningStageFuncContext if TYPE_CHECKING: from flash.data.data_pipeline import DataPipeline @@ -44,13 +40,11 @@ def __init__( data: DATA_TYPE, data_source: 'DataSource', running_stage: RunningStage, - data_pipeline: Optional['DataPipeline'] = None, ) -> None: super().__init__() self.data = data self.data_source = data_source - self.data_pipeline = data_pipeline self._running_stage = None self.running_stage = running_stage @@ -61,6 +55,7 @@ def running_stage(self) -> RunningStage: @running_stage.setter def running_stage(self, running_stage: RunningStage) -> None: + from flash.data.data_pipeline import DataPipeline from flash.data.data_source import DataSource # Hack to avoid circular import TODO: something better than this self._running_stage = running_stage @@ -69,7 +64,7 @@ def running_stage(self, running_stage: RunningStage) -> None: self.load_sample = getattr( self.data_source, - self.data_pipeline._resolve_function_hierarchy( + DataPipeline._resolve_function_hierarchy( 'load_sample', self.data_source, self.running_stage, @@ -77,17 +72,6 @@ def running_stage(self, running_stage: RunningStage) -> None: ) ) - @property - def preprocess(self) -> Optional[Preprocess]: - if self.data_pipeline is not None: - return self.data_pipeline._preprocess_pipeline - - @property - def control_flow_callback(self) -> Optional[ControlFlow]: - preprocess = self.preprocess - if preprocess is not None: - return ControlFlow(preprocess.callbacks) - def _call_load_sample(self, sample: Any) -> Any: if self.load_sample: with self._load_sample_context: @@ -96,8 +80,6 @@ def _call_load_sample(self, sample: Any) -> Any: sample = self.load_sample(sample, self) else: sample = self.load_sample(sample) - if self.control_flow_callback: - self.control_flow_callback.on_load_sample(sample, self.running_stage) return sample diff --git a/flash/data/batch.py b/flash/data/batch.py index ea6ce1e9ca..739f4704ea 100644 --- a/flash/data/batch.py +++ b/flash/data/batch.py @@ -57,6 +57,8 @@ def __init__( self._post_tensor_transform_context = CurrentFuncContext("post_tensor_transform", preprocess) def forward(self, sample: Any) -> Any: + self.callback.on_load_sample(sample, self.stage) + with self._current_stage_context: with self._pre_tensor_transform_context: sample = self.pre_tensor_transform(sample) diff --git a/flash/data/data_module.py b/flash/data/data_module.py index c594fcdecb..a11147ca27 100644 --- a/flash/data/data_module.py +++ b/flash/data/data_module.py @@ -67,6 +67,7 @@ def __init__( super().__init__() + self._data_source: DataSource = data_source self._preprocess: Optional[Preprocess] = preprocess self._postprocess: Optional[Postprocess] = postprocess self._viz: Optional[BaseVisualization] = None @@ -75,7 +76,7 @@ def __init__( # TODO: Preprocess can change self.data_fetcher.attach_to_preprocess(self.preprocess) - train_dataset, val_dataset, test_dataset, predict_dataset = 
data_source.to_datasets(self.data_pipeline) + train_dataset, val_dataset, test_dataset, predict_dataset = data_source.to_datasets() self._train_ds = train_dataset self._val_ds = val_dataset @@ -297,6 +298,10 @@ def num_classes(self) -> Optional[int]: or getattr(self.test_dataset, "num_classes", None) ) + @property + def data_source(self) -> DataSource: + return self._data_source + @property def preprocess(self) -> Preprocess: return self._preprocess or self.preprocess_cls() @@ -307,7 +312,7 @@ def postprocess(self) -> Postprocess: @property def data_pipeline(self) -> DataPipeline: - return DataPipeline(self.preprocess, self.postprocess) + return DataPipeline(self.data_source, self.preprocess, self.postprocess) @staticmethod def _split_train_val( diff --git a/flash/data/data_pipeline.py b/flash/data/data_pipeline.py index 89f152346d..40ff86f7df 100644 --- a/flash/data/data_pipeline.py +++ b/flash/data/data_pipeline.py @@ -24,8 +24,9 @@ from torch.utils.data import DataLoader, IterableDataset from torch.utils.data._utils.collate import default_collate, default_convert -from flash.data.auto_dataset import AutoDataset, IterableAutoDataset +from flash.data.auto_dataset import IterableAutoDataset from flash.data.batch import _PostProcessor, _PreProcessor, _Sequential +from flash.data.data_source import DataSource from flash.data.process import DefaultPreprocess, Postprocess, Preprocess, ProcessState, Serializer from flash.data.utils import _POSTPROCESS_FUNCS, _PREPROCESS_FUNCS, _STAGES_PREFIX @@ -88,10 +89,13 @@ class CustomPostprocess(Postprocess): def __init__( self, + data_source: Optional[DataSource] = None, preprocess: Optional[Preprocess] = None, postprocess: Optional[Postprocess] = None, serializer: Optional[Serializer] = None, ) -> None: + self._data_source = data_source + self._preprocess_pipeline = preprocess or DefaultPreprocess() self._postprocess_pipeline = postprocess or Postprocess() @@ -99,15 +103,18 @@ def __init__( self._running_stage = None - def initialize(self): + def initialize(self, data_pipeline_state: Optional[DataPipelineState]) -> DataPipelineState: """Creates the :class:`.DataPipelineState` and gives the reference to the: :class:`.Preprocess`, :class:`.Postprocess`, and :class:`.Serializer`. Once this has been called, any attempt to add new state will give a warning.""" - data_pipeline_state = DataPipelineState() + data_pipeline_state = data_pipeline_state or DataPipelineState() + if self._data_source is not None: + self._data_source.attach_data_pipeline_state(data_pipeline_state) self._preprocess_pipeline.attach_data_pipeline_state(data_pipeline_state) self._postprocess_pipeline.attach_data_pipeline_state(data_pipeline_state) self._serializer.attach_data_pipeline_state(data_pipeline_state) - data_pipeline_state._initialized = True + data_pipeline_state._initialized = True # TODO: Not sure we need this + return data_pipeline_state @staticmethod def _is_overriden(method_name: str, process_obj, super_obj: Any, prefix: Optional[str] = None) -> bool: @@ -506,25 +513,6 @@ def _detach_postprocess_from_model(model: 'Task'): # if any other pipeline is attached which may rely on this! 
model.predict_step = model.predict_step._original - # def _generate_callable_auto_dataset( - # self, data: Union[Iterable, Any], running_stage: RunningStage = None - # ) -> Callable: - # - # def fn(): - # return self._generate_auto_dataset(data, running_stage=running_stage) - # - # return fn - - # def _generate_auto_dataset( - # self, - # data: Union[Iterable, Any], - # running_stage: RunningStage = None, - # use_iterable_auto_dataset: bool = False - # ) -> Union[AutoDataset, IterableAutoDataset]: - # if use_iterable_auto_dataset: - # return IterableAutoDataset(data, data_pipeline=self, running_stage=running_stage) - # return AutoDataset(data=data, data_pipeline=self, running_stage=running_stage) - def to_dataloader( self, data: Union[Iterable, Any], auto_collate: Optional[bool] = None, **loader_kwargs ) -> DataLoader: diff --git a/flash/data/data_source.py b/flash/data/data_source.py index 314401637d..f3227aaa7d 100644 --- a/flash/data/data_source.py +++ b/flash/data/data_source.py @@ -13,17 +13,17 @@ # limitations under the License. import os import pathlib -from abc import ABC, abstractmethod +from abc import ABC from dataclasses import dataclass from enum import Enum from typing import Any, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple, Type, Union from pytorch_lightning.trainer.states import RunningStage +from torch.nn import Module from torchvision.datasets.folder import has_file_allowed_extension, make_dataset from flash.data.auto_dataset import AutoDataset, BaseAutoDataset, IterableAutoDataset -from flash.data.data_pipeline import DataPipeline -from flash.data.process import Preprocess, ProcessState +from flash.data.process import ProcessState, Properties from flash.data.utils import _STAGES_PREFIX, CurrentRunningStageFuncContext @@ -53,7 +53,7 @@ def __setattr__(self, key, value): object.__setattr__(self, key, value) -class DataSource(ABC): +class DataSource(Properties, Module, ABC): def __init__( self, @@ -69,18 +69,9 @@ def __init__( self.test_data = test_data self.predict_data = predict_data - self._preprocess: Optional[Preprocess] = None - - @property - def preprocess(self) -> Optional[Preprocess]: - return self._preprocess - - @abstractmethod - def load_data( - self, - data: Any, - dataset: Optional[Any] = None - ) -> Iterable[Mapping[str, Any]]: # TODO: decide what type this should be + def load_data(self, + data: Any, + dataset: Optional[Any] = None) -> Union[Sequence[Mapping[str, Any]], Iterable[Mapping[str, Any]]]: """Loads entire data from Dataset. The input ``data`` can be anything, but you need to return a Mapping. 
Example:: @@ -91,44 +82,50 @@ def load_data( output: Mapping = load_data(data) """ + return data - @abstractmethod def load_sample(self, sample: Mapping[str, Any], dataset: Optional[Any] = None) -> Any: """Loads single sample from dataset""" + return sample - def to_datasets(self, data_pipeline: DataPipeline) -> Tuple[Optional[BaseAutoDataset], ...]: - # attach preprocess - self._preprocess = data_pipeline._preprocess_pipeline - - train_dataset = self._generate_dataset_if_possible(RunningStage.TRAINING, data_pipeline) - val_dataset = self._generate_dataset_if_possible(RunningStage.VALIDATING, data_pipeline) - test_dataset = self._generate_dataset_if_possible(RunningStage.TESTING, data_pipeline) - predict_dataset = self._generate_dataset_if_possible(RunningStage.PREDICTING, data_pipeline) + def to_datasets(self) -> Tuple[Optional[BaseAutoDataset], ...]: + train_dataset = self._generate_dataset_if_possible(RunningStage.TRAINING) + val_dataset = self._generate_dataset_if_possible(RunningStage.VALIDATING) + test_dataset = self._generate_dataset_if_possible(RunningStage.TESTING) + predict_dataset = self._generate_dataset_if_possible(RunningStage.PREDICTING) return train_dataset, val_dataset, test_dataset, predict_dataset def _generate_dataset_if_possible( self, running_stage: RunningStage, - data_pipeline: DataPipeline, ) -> Optional[Union[AutoDataset, IterableAutoDataset]]: data = getattr(self, f"{_STAGES_PREFIX[running_stage]}_data", None) if data is not None: - return self.generate_dataset(data, running_stage, data_pipeline) + return self.generate_dataset(data, running_stage) def generate_dataset( self, data, running_stage: RunningStage, - data_pipeline: DataPipeline, ) -> Optional[Union[AutoDataset, IterableAutoDataset]]: + from flash.data.data_pipeline import DataPipeline + mock_dataset = MockDataset() with CurrentRunningStageFuncContext(running_stage, "load_data", self): - data = self.load_data(data, mock_dataset) # TODO: Should actually resolve this + load_data = getattr( + self, DataPipeline._resolve_function_hierarchy( + 'load_data', + self, + running_stage, + DataSource, + ) + ) + data = load_data(data, mock_dataset) if has_len(data): - dataset = AutoDataset(data, self, running_stage, data_pipeline) + dataset = AutoDataset(data, self, running_stage) else: - dataset = IterableAutoDataset(data, self, running_stage, data_pipeline) + dataset = IterableAutoDataset(data, self, running_stage) dataset.__dict__.update(mock_dataset.metadata) return dataset @@ -169,7 +166,7 @@ def __init__( self.labels = labels if self.labels is not None: - self.preprocess.set_state(LabelsState(self.labels)) + self.set_state(LabelsState(self.labels)) def load_data(self, data: Any, dataset: Optional[Any] = None) -> Iterable[Mapping[str, Any]]: inputs, targets = data @@ -224,7 +221,7 @@ def load_data(self, data: Any, dataset: Optional[Any] = None) -> Iterable[Mappin files, )] else: - self.preprocess.set_state(LabelsState(classes)) + self.set_state(LabelsState(classes)) dataset.num_classes = len(classes) data = make_dataset(data, class_to_idx, extensions=self.extensions) return [{'input': input, 'target': target} for input, target in data] diff --git a/flash/vision/classification/data.py b/flash/vision/classification/data.py index 1b5b6d0dcd..7da6e2b82f 100644 --- a/flash/vision/classification/data.py +++ b/flash/vision/classification/data.py @@ -182,17 +182,17 @@ def pre_tensor_transform(self, sample: Any) -> Any: return self.common_step(sample) def to_tensor_transform(self, sample: Any) -> Any: - if 
self.current_transform == self._identity: - if isinstance(sample, (list, tuple)): - source, target = sample - if isinstance(source, torch.Tensor): - return source, target - return self.to_tensor(source), target - elif isinstance(sample, torch.Tensor): - return sample - return self.to_tensor(sample) - if isinstance(sample, torch.Tensor): - return sample + # if self.current_transform == self._identity: + # if isinstance(sample, (list, tuple)): + # source, target = sample + # if isinstance(source, torch.Tensor): + # return source, target + # return self.to_tensor(source), target + # elif isinstance(sample, torch.Tensor): + # return sample + # return self.to_tensor(sample) + # if isinstance(sample, torch.Tensor): + # return sample return self.common_step(sample) def post_tensor_transform(self, sample: Any) -> Any: diff --git a/flash_examples/finetuning/image_classification.py b/flash_examples/finetuning/image_classification.py index 4a93ec1785..2ebc668f95 100644 --- a/flash_examples/finetuning/image_classification.py +++ b/flash_examples/finetuning/image_classification.py @@ -66,13 +66,9 @@ def fn_resnet(pretrained: bool = True): "data/hymenoptera_data/val/bees/590318879_68cf112861.jpg", "data/hymenoptera_data/val/ants/540543309_ddbb193ee5.jpg", ]) - print(predictions) -datamodule = ImageClassificationData.from_folders( - predict_folder="data/hymenoptera_data/predict/", - preprocess=model.preprocess, -) +datamodule = ImageClassificationData.from_folders(predict_folder="data/hymenoptera_data/predict/") # 7b. Or generate predictions with a whole folder! predictions = Trainer().predict(model, datamodule=datamodule) From 653057db1f19f46cb34160e57f8b53f1843214b4 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Fri, 30 Apr 2021 19:15:25 +0100 Subject: [PATCH 06/78] Expand preprocess --- flash/data/data_module.py | 9 +- flash/data/process.py | 67 +++++++--- flash/vision/classification/data.py | 155 +++------------------- flash/vision/classification/transforms.py | 91 +++++++++++++ 4 files changed, 160 insertions(+), 162 deletions(-) create mode 100644 flash/vision/classification/transforms.py diff --git a/flash/data/data_module.py b/flash/data/data_module.py index a11147ca27..74260d8e0f 100644 --- a/flash/data/data_module.py +++ b/flash/data/data_module.py @@ -14,20 +14,17 @@ import os import pathlib import platform -from enum import Enum -from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type, TypeVar, Union +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union import numpy as np import pytorch_lightning as pl import torch -from datasets.splits import SplitInfo from pytorch_lightning.trainer.states import RunningStage from pytorch_lightning.utilities.exceptions import MisconfigurationException -from torch.nn import Module from torch.utils.data import DataLoader, Dataset from torch.utils.data.dataset import IterableDataset, Subset -from flash.data.auto_dataset import AutoDataset, BaseAutoDataset, IterableAutoDataset +from flash.data.auto_dataset import BaseAutoDataset, IterableAutoDataset from flash.data.base_viz import BaseVisualization from flash.data.callback import BaseDataFetcher from flash.data.data_pipeline import DataPipeline, DefaultPreprocess, Postprocess, Preprocess @@ -350,6 +347,7 @@ def from_folders( val_split: Optional[float] = None, batch_size: int = 4, num_workers: Optional[int] = None, + **kwargs: Any, ) -> 'DataModule': preprocess = preprocess or cls.preprocess_cls( @@ -357,6 +355,7 @@ def from_folders( val_transform, 
test_transform, predict_transform, + **kwargs, ) data_source = preprocess.data_source_of_type(FoldersDataSource)( diff --git a/flash/data/process.py b/flash/data/process.py index 4d4368712a..4d16f3154a 100644 --- a/flash/data/process.py +++ b/flash/data/process.py @@ -151,6 +151,9 @@ def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool = False): pass +DATA_SOURCE_TYPE = TypeVar("DATA_SOURCE_TYPE") + + class Preprocess(BasePreprocess, Properties, Module): """ The :class:`~flash.data.process.Preprocess` encapsulates @@ -309,9 +312,17 @@ def __init__( val_transform: Optional[Dict[str, Callable]] = None, test_transform: Optional[Dict[str, Callable]] = None, predict_transform: Optional[Dict[str, Callable]] = None, + data_sources: Optional[List[Type['DataSource']]] = None, ): super().__init__() + self.data_sources = data_sources or [] + + # resolve the default transforms + train_transform, val_transform, test_transform, predict_transform = self._resolve_transforms( + train_transform, val_transform, test_transform, predict_transform + ) + # used to keep track of provided transforms self._train_collate_in_worker_from_transform: Optional[bool] = None self._val_collate_in_worker_from_transform: Optional[bool] = None @@ -342,6 +353,33 @@ def _save_to_state_dict(self, destination, prefix, keep_vars): destination['preprocess.state_dict'] = preprocess_state_dict return super()._save_to_state_dict(destination, prefix, keep_vars) + def default_train_transforms(self) -> Optional[Dict[str, Callable]]: + return None + + def default_val_transforms(self) -> Optional[Dict[str, Callable]]: + return None + + def _resolve_transforms( + self, + train_transform: Optional[Union[str, Dict]] = 'default', + val_transform: Optional[Union[str, Dict]] = 'default', + test_transform: Optional[Union[str, Dict]] = 'default', + predict_transform: Optional[Union[str, Dict]] = 'default', + ): + if not train_transform or train_transform == 'default': + train_transform = self.default_train_transforms() + + if not val_transform or val_transform == 'default': + val_transform = self.default_val_transforms() + + if not test_transform or test_transform == 'default': + test_transform = self.default_val_transforms() + + if not predict_transform or predict_transform == 'default': + predict_transform = self.default_val_transforms() + + return train_transform, val_transform, test_transform, predict_transform + def _check_transforms(self, transform: Optional[Dict[str, Callable]], stage: RunningStage) -> Optional[Dict[str, Callable]]: if transform is None: @@ -430,15 +468,15 @@ def add_callbacks(self, callbacks: List['FlashCallback']): def pre_tensor_transform(self, sample: Any) -> Any: """Transforms to apply on a single object.""" - return sample + return self.current_transform(sample) def to_tensor_transform(self, sample: Any) -> Tensor: """Transforms to convert single object to a tensor.""" - return sample + return self.current_transform(sample) def post_tensor_transform(self, sample: Tensor) -> Tensor: """Transforms to apply on a tensor.""" - return sample + return self.current_transform(sample) def per_batch_transform(self, batch: Any) -> Any: """Transforms to apply to a whole batch (if possible use this for efficiency). @@ -448,7 +486,7 @@ def per_batch_transform(self, batch: Any) -> Any: This option is mutually exclusive with :meth:`per_sample_transform_on_device`, since if both are specified, uncollation has to be applied. 
""" - return batch + return self.current_transform(batch) def collate(self, samples: Sequence) -> Any: return default_collate(samples) @@ -466,7 +504,7 @@ def per_sample_transform_on_device(self, sample: Any) -> Any: This function won't be called within the dataloader workers, since to make that happen each of the workers would have to create it's own CUDA-context which would pollute GPU memory (if on GPU). """ - return sample + return self.current_transform(sample) def per_batch_transform_on_device(self, batch: Any) -> Any: """ @@ -477,25 +515,18 @@ def per_batch_transform_on_device(self, batch: Any) -> Any: This function won't be called within the dataloader workers, since to make that happen each of the workers would have to create it's own CUDA-context which would pollute GPU memory (if on GPU). """ - return batch - - -T = TypeVar("T") - + return self.current_transform(batch) -class DefaultPreprocess(Preprocess): - - data_sources: List[Type['DataSource']] = [] # TODO: Make this a property - - # TODO: Doesn't need to be a classmethod - @classmethod - def data_source_of_type(cls, data_source_type: Type[T]) -> Optional[Type[T]]: - data_sources = cls.data_sources + def data_source_of_type(self, data_source_type: Type[DATA_SOURCE_TYPE]) -> Optional[Type[DATA_SOURCE_TYPE]]: + data_sources = self.data_sources for data_source in data_sources: if issubclass(data_source, data_source_type): return data_source return None + +class DefaultPreprocess(Preprocess): + def get_state_dict(self) -> Dict[str, Any]: return {} diff --git a/flash/vision/classification/data.py b/flash/vision/classification/data.py index 7da6e2b82f..be3aa6415b 100644 --- a/flash/vision/classification/data.py +++ b/flash/vision/classification/data.py @@ -11,40 +11,30 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-import os from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union import numpy as np import torch -import torchvision from PIL import Image from pytorch_lightning.trainer.states import RunningStage from pytorch_lightning.utilities.exceptions import MisconfigurationException -from torch import nn from torch.utils.data._utils.collate import default_collate -from torchvision import transforms as T from flash.data.base_viz import BaseVisualization # for viz from flash.data.callback import BaseDataFetcher from flash.data.data_module import DataModule -from flash.data.process import DefaultPreprocess -from flash.data.transforms import ApplyToKeys -from flash.utils.imports import _KORNIA_AVAILABLE, _MATPLOTLIB_AVAILABLE +from flash.data.process import Preprocess +from flash.utils.imports import _MATPLOTLIB_AVAILABLE +from flash.vision.classification.transforms import default_train_transforms, default_val_transforms from flash.vision.data import ImageFilesDataSource, ImageFoldersDataSource -if _KORNIA_AVAILABLE: - import kornia as K - if _MATPLOTLIB_AVAILABLE: import matplotlib.pyplot as plt else: plt = None -class ImageClassificationPreprocess(DefaultPreprocess): - - data_sources = [ImageFoldersDataSource, ImageFilesDataSource] - to_tensor = T.ToTensor() +class ImageClassificationPreprocess(Preprocess): def __init__( self, @@ -54,11 +44,15 @@ def __init__( predict_transform: Optional[Union[Dict[str, Callable]]] = None, image_size: Tuple[int, int] = (196, 196), ): - train_transform, val_transform, test_transform, predict_transform = self._resolve_transforms( - train_transform, val_transform, test_transform, predict_transform, image_size - ) self.image_size = image_size - super().__init__(train_transform, val_transform, test_transform, predict_transform) + + super().__init__( + train_transform=train_transform, + val_transform=val_transform, + test_transform=test_transform, + predict_transform=predict_transform, + data_sources=[ImageFoldersDataSource, ImageFilesDataSource], + ) def get_state_dict(self) -> Dict[str, Any]: return { @@ -73,97 +67,6 @@ def get_state_dict(self) -> Dict[str, Any]: def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool): return cls(**state_dict) - def default_train_transforms(self, image_size: Tuple[int, int]) -> Dict[str, Callable]: - if _KORNIA_AVAILABLE and not os.getenv("FLASH_TESTING", "0") == "1": - # Better approach as all transforms are applied on tensor directly - return { - "to_tensor_transform": nn.Sequential( - ApplyToKeys('input', torchvision.transforms.ToTensor()), - ApplyToKeys('target', torch.as_tensor), - ), - "post_tensor_transform": ApplyToKeys( - 'input', - # TODO (Edgar): replace with resize once kornia is fixed - K.augmentation.RandomResizedCrop(image_size, scale=(1.0, 1.0), ratio=(1.0, 1.0)), - K.augmentation.RandomHorizontalFlip(), - ), - "per_batch_transform_on_device": ApplyToKeys( - 'input', - K.augmentation.Normalize(torch.tensor([0.485, 0.456, 0.406]), torch.tensor([0.229, 0.224, 0.225])), - ) - } - else: - return { - "pre_tensor_transform": ApplyToKeys('input', T.Resize(image_size), T.RandomHorizontalFlip()), - "to_tensor_transform": nn.Sequential( - ApplyToKeys('input', torchvision.transforms.ToTensor()), - ApplyToKeys('target', torch.as_tensor), - ), - "post_tensor_transform": ApplyToKeys( - 'input', - T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), - ), - } - - def default_val_transforms(self, image_size: Tuple[int, int]) -> Dict[str, Callable]: - if _KORNIA_AVAILABLE and not 
os.getenv("FLASH_TESTING", "0") == "1": - # Better approach as all transforms are applied on tensor directly - return { - "to_tensor_transform": nn.Sequential( - ApplyToKeys('input', torchvision.transforms.ToTensor()), - ApplyToKeys('target', torch.as_tensor), - ), - "post_tensor_transform": ApplyToKeys( - 'input', - # TODO (Edgar): replace with resize once kornia is fixed - K.augmentation.RandomResizedCrop(image_size, scale=(1.0, 1.0), ratio=(1.0, 1.0)), - ), - "per_batch_transform_on_device": ApplyToKeys( - 'input', - K.augmentation.Normalize(torch.tensor([0.485, 0.456, 0.406]), torch.tensor([0.229, 0.224, 0.225])), - ) - } - else: - return { - "pre_tensor_transform": ApplyToKeys('input', T.Resize(image_size)), - "to_tensor_transform": nn.Sequential( - ApplyToKeys('input', torchvision.transforms.ToTensor()), - ApplyToKeys('target', torch.as_tensor), - ), - "post_tensor_transform": ApplyToKeys( - 'input', - T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), - ), - } - - def _resolve_transforms( - self, - train_transform: Optional[Union[str, Dict]] = 'default', - val_transform: Optional[Union[str, Dict]] = 'default', - test_transform: Optional[Union[str, Dict]] = 'default', - predict_transform: Optional[Union[str, Dict]] = 'default', - image_size: Tuple[int, int] = (196, 196), - ): - - if not train_transform or train_transform == 'default': - train_transform = self.default_train_transforms(image_size) - - if not val_transform or val_transform == 'default': - val_transform = self.default_val_transforms(image_size) - - if not test_transform or test_transform == 'default': - test_transform = self.default_val_transforms(image_size) - - if not predict_transform or predict_transform == 'default': - predict_transform = self.default_val_transforms(image_size) - - return ( - train_transform, - val_transform, - test_transform, - predict_transform, - ) - def collate(self, samples: Sequence[Dict[str, Any]]) -> Any: # todo: Kornia transforms add batch dimension which need to be removed for sample in samples: @@ -172,37 +75,11 @@ def collate(self, samples: Sequence[Dict[str, Any]]) -> Any: sample[key] = sample[key].squeeze(0) return default_collate(samples) - def common_step(self, sample: Any) -> Any: - # if isinstance(sample, (list, tuple)): - # source, target = sample - # return self.current_transform(source), target - return self.current_transform(sample) - - def pre_tensor_transform(self, sample: Any) -> Any: - return self.common_step(sample) - - def to_tensor_transform(self, sample: Any) -> Any: - # if self.current_transform == self._identity: - # if isinstance(sample, (list, tuple)): - # source, target = sample - # if isinstance(source, torch.Tensor): - # return source, target - # return self.to_tensor(source), target - # elif isinstance(sample, torch.Tensor): - # return sample - # return self.to_tensor(sample) - # if isinstance(sample, torch.Tensor): - # return sample - return self.common_step(sample) - - def post_tensor_transform(self, sample: Any) -> Any: - return self.common_step(sample) - - def per_batch_transform(self, sample: Any) -> Any: - return self.common_step(sample) + def default_train_transforms(self) -> Optional[Dict[str, Callable]]: + return default_train_transforms(self.image_size) - def per_batch_transform_on_device(self, sample: Any) -> Any: - return self.common_step(sample) + def default_val_transforms(self) -> Optional[Dict[str, Callable]]: + return default_val_transforms(self.image_size) class ImageClassificationData(DataModule): diff --git 
a/flash/vision/classification/transforms.py b/flash/vision/classification/transforms.py new file mode 100644 index 0000000000..90454524b3 --- /dev/null +++ b/flash/vision/classification/transforms.py @@ -0,0 +1,91 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +from typing import Callable, Dict, Tuple + +import torch +import torchvision +from torch import nn +from torchvision import transforms as T + +from flash.data.transforms import ApplyToKeys +from flash.utils.imports import _KORNIA_AVAILABLE + +if _KORNIA_AVAILABLE: + import kornia as K + + +def default_train_transforms(image_size: Tuple[int, int]) -> Dict[str, Callable]: + if _KORNIA_AVAILABLE and not os.getenv("FLASH_TESTING", "0") == "1": + # Better approach as all transforms are applied on tensor directly + return { + "to_tensor_transform": nn.Sequential( + ApplyToKeys('input', torchvision.transforms.ToTensor()), + ApplyToKeys('target', torch.as_tensor), + ), + "post_tensor_transform": ApplyToKeys( + 'input', + # TODO (Edgar): replace with resize once kornia is fixed + K.augmentation.RandomResizedCrop(image_size, scale=(1.0, 1.0), ratio=(1.0, 1.0)), + K.augmentation.RandomHorizontalFlip(), + ), + "per_batch_transform_on_device": ApplyToKeys( + 'input', + K.augmentation.Normalize(torch.tensor([0.485, 0.456, 0.406]), torch.tensor([0.229, 0.224, 0.225])), + ) + } + else: + return { + "pre_tensor_transform": ApplyToKeys('input', T.Resize(image_size), T.RandomHorizontalFlip()), + "to_tensor_transform": nn.Sequential( + ApplyToKeys('input', torchvision.transforms.ToTensor()), + ApplyToKeys('target', torch.as_tensor), + ), + "post_tensor_transform": ApplyToKeys( + 'input', + T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), + ), + } + + +def default_val_transforms(image_size: Tuple[int, int]) -> Dict[str, Callable]: + if _KORNIA_AVAILABLE and not os.getenv("FLASH_TESTING", "0") == "1": + # Better approach as all transforms are applied on tensor directly + return { + "to_tensor_transform": nn.Sequential( + ApplyToKeys('input', torchvision.transforms.ToTensor()), + ApplyToKeys('target', torch.as_tensor), + ), + "post_tensor_transform": ApplyToKeys( + 'input', + # TODO (Edgar): replace with resize once kornia is fixed + K.augmentation.RandomResizedCrop(image_size, scale=(1.0, 1.0), ratio=(1.0, 1.0)), + ), + "per_batch_transform_on_device": ApplyToKeys( + 'input', + K.augmentation.Normalize(torch.tensor([0.485, 0.456, 0.406]), torch.tensor([0.229, 0.224, 0.225])), + ) + } + else: + return { + "pre_tensor_transform": ApplyToKeys('input', T.Resize(image_size)), + "to_tensor_transform": nn.Sequential( + ApplyToKeys('input', torchvision.transforms.ToTensor()), + ApplyToKeys('target', torch.as_tensor), + ), + "post_tensor_transform": ApplyToKeys( + 'input', + T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), + ), + } From 0184332a4f0a53dcd8f00d5c9baab8bae84e04e9 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Fri, 30 Apr 2021 19:38:03 +0100 Subject: [PATCH 07/78] Fixes --- 
flash/data/data_pipeline.py | 26 +++----------------------- 1 file changed, 3 insertions(+), 23 deletions(-) diff --git a/flash/data/data_pipeline.py b/flash/data/data_pipeline.py index 40ff86f7df..0f8214d6c1 100644 --- a/flash/data/data_pipeline.py +++ b/flash/data/data_pipeline.py @@ -14,7 +14,7 @@ import functools import inspect import weakref -from typing import Any, Callable, Dict, Iterable, Optional, Sequence, Set, Tuple, Type, TYPE_CHECKING, Union +from typing import Any, Callable, Dict, Optional, Sequence, Set, Tuple, Type, TYPE_CHECKING import torch from pytorch_lightning.trainer.connectors.data_connector import _PatchDataLoader @@ -22,7 +22,7 @@ from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException from torch.utils.data import DataLoader, IterableDataset -from torch.utils.data._utils.collate import default_collate, default_convert +from torch.utils.data._utils.collate import default_collate from flash.data.auto_dataset import IterableAutoDataset from flash.data.batch import _PostProcessor, _PreProcessor, _Sequential @@ -108,6 +108,7 @@ def initialize(self, data_pipeline_state: Optional[DataPipelineState]) -> DataPi :class:`.Postprocess`, and :class:`.Serializer`. Once this has been called, any attempt to add new state will give a warning.""" data_pipeline_state = data_pipeline_state or DataPipelineState() + data_pipeline_state._initialized = False if self._data_source is not None: self._data_source.attach_data_pipeline_state(data_pipeline_state) self._preprocess_pipeline.attach_data_pipeline_state(data_pipeline_state) @@ -513,27 +514,6 @@ def _detach_postprocess_from_model(model: 'Task'): # if any other pipeline is attached which may rely on this! model.predict_step = model.predict_step._original - def to_dataloader( - self, data: Union[Iterable, Any], auto_collate: Optional[bool] = None, **loader_kwargs - ) -> DataLoader: - if 'collate_fn' in loader_kwargs: - if auto_collate: - raise MisconfigurationException('auto_collate and collate_fn are mutually exclusive') - - else: - if auto_collate is None: - auto_collate = True - - collate_fn = self.worker_collate_fn - - if collate_fn: - loader_kwargs['collate_fn'] = collate_fn - - else: - loader_kwargs['collate_fn'] = default_collate if auto_collate else default_convert - - return DataLoader(self._generate_auto_dataset(data), **loader_kwargs) - def __str__(self) -> str: preprocess: Preprocess = self._preprocess_pipeline postprocess: Postprocess = self._postprocess_pipeline From 5172a06eb52c20174fd7a896c2ba86c97bf1462f Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Fri, 30 Apr 2021 19:39:29 +0100 Subject: [PATCH 08/78] Fixes --- flash/data/transforms.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/flash/data/transforms.py b/flash/data/transforms.py index 3a0c12bd1b..eed5b640eb 100644 --- a/flash/data/transforms.py +++ b/flash/data/transforms.py @@ -1,3 +1,16 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
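The default transforms above are all expressed as dictionaries mapping hook names (``pre_tensor_transform``, ``to_tensor_transform``, ``post_tensor_transform``, ``per_batch_transform_on_device``) to callables, with ``ApplyToKeys`` routing each callable to the ``'input'`` or ``'target'`` entry of the sample dict; hooks left out of the dictionary appear to fall back to an identity transform. A hedged usage sketch, where the folder paths are illustrative and the import path follows the modules touched in this series::

    import torch
    from torch import nn
    from torchvision import transforms as T

    from flash.data.transforms import ApplyToKeys
    from flash.vision.classification.data import ImageClassificationData

    # Override only the training transforms; validation/test keep their defaults.
    train_transform = {
        "pre_tensor_transform": ApplyToKeys("input", T.Resize((196, 196)), T.RandomHorizontalFlip()),
        "to_tensor_transform": nn.Sequential(
            ApplyToKeys("input", T.ToTensor()),
            ApplyToKeys("target", torch.as_tensor),
        ),
        "post_tensor_transform": ApplyToKeys(
            "input",
            T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ),
    }

    datamodule = ImageClassificationData.from_folders(
        train_folder="data/hymenoptera_data/train/",
        val_folder="data/hymenoptera_data/val/",
        train_transform=train_transform,
        batch_size=4,
    )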
+# See the License for the specific language governing permissions and +# limitations under the License. from typing import Any, Mapping, Sequence, Union from torch import nn From 5c3f597d8881ae6ddcbf7384cb2d2a28113957e1 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Fri, 30 Apr 2021 20:00:49 +0100 Subject: [PATCH 09/78] Cleaning --- flash/data/data_module.py | 104 +++++++++++++++++++++++++--- flash/data/process.py | 20 +----- flash/vision/classification/data.py | 2 +- 3 files changed, 95 insertions(+), 31 deletions(-) diff --git a/flash/data/data_module.py b/flash/data/data_module.py index 74260d8e0f..453adf53ef 100644 --- a/flash/data/data_module.py +++ b/flash/data/data_module.py @@ -14,7 +14,7 @@ import os import pathlib import platform -from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Type, TypeVar, Union import numpy as np import pytorch_lightning as pl @@ -28,10 +28,12 @@ from flash.data.base_viz import BaseVisualization from flash.data.callback import BaseDataFetcher from flash.data.data_pipeline import DataPipeline, DefaultPreprocess, Postprocess, Preprocess -from flash.data.data_source import DataSource, FoldersDataSource +from flash.data.data_source import DataSource, FilesDataSource, FoldersDataSource from flash.data.splits import SplitDataset from flash.data.utils import _STAGES_PREFIX +DATA_SOURCE_TYPE = TypeVar("DATA_SOURCE_TYPE") + class DataModule(pl.LightningDataModule): """Basic DataModule class for all Flash tasks @@ -48,6 +50,7 @@ class DataModule(pl.LightningDataModule): or 0 for Darwin platform. """ + data_sources = [] preprocess_cls = DefaultPreprocess postprocess_cls = Postprocess @@ -332,12 +335,17 @@ def _split_train_val( return SplitDataset(train_dataset, train_indices), SplitDataset(train_dataset, val_indices) @classmethod - def from_folders( + def data_source_of_type(cls, data_source_type: Type[DATA_SOURCE_TYPE]) -> Optional[Type[DATA_SOURCE_TYPE]]: + data_sources = cls.data_sources + for data_source in data_sources: + if issubclass(data_source, data_source_type): + return data_source + return None + + @classmethod + def from_data_source( cls, - train_folder: Optional[Union[str, pathlib.Path]] = None, - val_folder: Optional[Union[str, pathlib.Path]] = None, - test_folder: Optional[Union[str, pathlib.Path]] = None, - predict_folder: Union[str, pathlib.Path] = None, + data_source: DataSource, train_transform: Optional[Union[str, Dict]] = 'default', val_transform: Optional[Union[str, Dict]] = 'default', test_transform: Optional[Union[str, Dict]] = 'default', @@ -349,7 +357,6 @@ def from_folders( num_workers: Optional[int] = None, **kwargs: Any, ) -> 'DataModule': - preprocess = preprocess or cls.preprocess_cls( train_transform, val_transform, @@ -358,17 +365,92 @@ def from_folders( **kwargs, ) - data_source = preprocess.data_source_of_type(FoldersDataSource)( + return cls( + data_source, + preprocess, + data_fetcher=data_fetcher, + val_split=val_split, + batch_size=batch_size, + num_workers=num_workers, + ) + + @classmethod + def from_folders( + cls, + train_folder: Optional[Union[str, pathlib.Path]] = None, + val_folder: Optional[Union[str, pathlib.Path]] = None, + test_folder: Optional[Union[str, pathlib.Path]] = None, + predict_folder: Union[str, pathlib.Path] = None, + train_transform: Optional[Union[str, Dict]] = 'default', + val_transform: Optional[Union[str, Dict]] = 'default', + test_transform: Optional[Union[str, Dict]] = 'default', + 
predict_transform: Optional[Union[str, Dict]] = 'default', + data_fetcher: BaseDataFetcher = None, + preprocess: Optional[Preprocess] = None, + val_split: Optional[float] = None, + batch_size: int = 4, + num_workers: Optional[int] = None, + **kwargs: Any, + ) -> 'DataModule': + data_source = cls.data_source_of_type(FoldersDataSource)( train_folder=train_folder, val_folder=val_folder, test_folder=test_folder, predict_folder=predict_folder, ) - return cls( + return cls.from_data_source( data_source, - preprocess, + train_transform=train_transform, + val_transform=val_transform, + test_transform=test_transform, + predict_transform=predict_transform, + data_fetcher=data_fetcher, + preprocess=preprocess, + val_split=val_split, + batch_size=batch_size, + num_workers=num_workers, + ) + + @classmethod + def from_files( + cls, + train_files: Optional[Sequence[Union[str, pathlib.Path]]] = None, + train_targets: Optional[Sequence[Any]] = None, + val_files: Optional[Sequence[Union[str, pathlib.Path]]] = None, + val_targets: Optional[Sequence[Any]] = None, + test_files: Optional[Sequence[Union[str, pathlib.Path]]] = None, + test_targets: Optional[Sequence[Any]] = None, + predict_files: Optional[Sequence[Union[str, pathlib.Path]]] = None, + train_transform: Optional[Union[str, Dict]] = 'default', + val_transform: Optional[Union[str, Dict]] = 'default', + test_transform: Optional[Union[str, Dict]] = 'default', + predict_transform: Optional[Union[str, Dict]] = 'default', + data_fetcher: BaseDataFetcher = None, + preprocess: Optional[Preprocess] = None, + val_split: Optional[float] = None, + batch_size: int = 4, + num_workers: Optional[int] = None, + **kwargs: Any, + ) -> 'DataModule': + data_source = cls.data_source_of_type(FilesDataSource)( + train_files=train_files, + train_targets=train_targets, + val_files=val_files, + val_targets=val_targets, + test_files=test_files, + test_targets=test_targets, + predict_files=predict_files, + ) + + return cls.from_data_source( + data_source, + train_transform=train_transform, + val_transform=val_transform, + test_transform=test_transform, + predict_transform=predict_transform, data_fetcher=data_fetcher, + preprocess=preprocess, val_split=val_split, batch_size=batch_size, num_workers=num_workers, diff --git a/flash/data/process.py b/flash/data/process.py index 4d16f3154a..813717058a 100644 --- a/flash/data/process.py +++ b/flash/data/process.py @@ -11,13 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
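Together with ``from_folders``, the new ``from_files`` constructor above lets a task-specific ``DataModule`` be built straight from in-memory lists of paths and targets; the matching ``FilesDataSource`` is looked up from the data sources registered on the preprocess. A hedged sketch of how that is intended to be called, where the file paths are placeholders rather than files shipped with the repository::

    from flash.vision.classification.data import ImageClassificationData

    datamodule = ImageClassificationData.from_files(
        train_files=[
            "data/hymenoptera_data/train/ants/ant_1.jpg",
            "data/hymenoptera_data/train/bees/bee_1.jpg",
        ],
        train_targets=[0, 1],
        val_files=["data/hymenoptera_data/val/ants/ant_2.jpg"],
        val_targets=[0],
        batch_size=2,
    )

    # The resulting dataloaders yield batches built from the same 'input'/'target'
    # dicts that the data sources earlier in this series produce.
    # train_loader = datamodule.train_dataloader()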
-import functools import os -import subprocess -from abc import ABC, ABCMeta, abstractclassmethod, abstractmethod, abstractproperty, abstractstaticmethod +from abc import ABC, abstractclassmethod, abstractmethod from dataclasses import dataclass -from importlib import import_module -from operator import truediv from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Type, TYPE_CHECKING, TypeVar, Union import torch @@ -33,7 +29,6 @@ if TYPE_CHECKING: from flash.data.data_pipeline import DataPipelineState - from flash.data.data_source import DataSource @dataclass(unsafe_hash=True, frozen=True) @@ -151,9 +146,6 @@ def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool = False): pass -DATA_SOURCE_TYPE = TypeVar("DATA_SOURCE_TYPE") - - class Preprocess(BasePreprocess, Properties, Module): """ The :class:`~flash.data.process.Preprocess` encapsulates @@ -312,12 +304,9 @@ def __init__( val_transform: Optional[Dict[str, Callable]] = None, test_transform: Optional[Dict[str, Callable]] = None, predict_transform: Optional[Dict[str, Callable]] = None, - data_sources: Optional[List[Type['DataSource']]] = None, ): super().__init__() - self.data_sources = data_sources or [] - # resolve the default transforms train_transform, val_transform, test_transform, predict_transform = self._resolve_transforms( train_transform, val_transform, test_transform, predict_transform @@ -517,13 +506,6 @@ def per_batch_transform_on_device(self, batch: Any) -> Any: """ return self.current_transform(batch) - def data_source_of_type(self, data_source_type: Type[DATA_SOURCE_TYPE]) -> Optional[Type[DATA_SOURCE_TYPE]]: - data_sources = self.data_sources - for data_source in data_sources: - if issubclass(data_source, data_source_type): - return data_source - return None - class DefaultPreprocess(Preprocess): diff --git a/flash/vision/classification/data.py b/flash/vision/classification/data.py index be3aa6415b..3b6f4d9716 100644 --- a/flash/vision/classification/data.py +++ b/flash/vision/classification/data.py @@ -64,7 +64,7 @@ def get_state_dict(self) -> Dict[str, Any]: } @classmethod - def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool): + def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool = False): return cls(**state_dict) def collate(self, samples: Sequence[Dict[str, Any]]) -> Any: From 44d70e16ae74f890376a2a52175ef9c43e2b0291 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Fri, 30 Apr 2021 20:11:46 +0100 Subject: [PATCH 10/78] Fixes --- flash/data/data_module.py | 17 +++-------------- flash/data/process.py | 14 ++++++++++++++ flash/vision/classification/data.py | 3 ++- 3 files changed, 19 insertions(+), 15 deletions(-) diff --git a/flash/data/data_module.py b/flash/data/data_module.py index 453adf53ef..ae2c493400 100644 --- a/flash/data/data_module.py +++ b/flash/data/data_module.py @@ -14,7 +14,7 @@ import os import pathlib import platform -from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Type, TypeVar, Union +from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union import numpy as np import pytorch_lightning as pl @@ -32,8 +32,6 @@ from flash.data.splits import SplitDataset from flash.data.utils import _STAGES_PREFIX -DATA_SOURCE_TYPE = TypeVar("DATA_SOURCE_TYPE") - class DataModule(pl.LightningDataModule): """Basic DataModule class for all Flash tasks @@ -50,7 +48,6 @@ class DataModule(pl.LightningDataModule): or 0 for Darwin platform. 
""" - data_sources = [] preprocess_cls = DefaultPreprocess postprocess_cls = Postprocess @@ -334,14 +331,6 @@ def _split_train_val( train_indices = [i for i in range(train_num_samples) if i not in val_indices] return SplitDataset(train_dataset, train_indices), SplitDataset(train_dataset, val_indices) - @classmethod - def data_source_of_type(cls, data_source_type: Type[DATA_SOURCE_TYPE]) -> Optional[Type[DATA_SOURCE_TYPE]]: - data_sources = cls.data_sources - for data_source in data_sources: - if issubclass(data_source, data_source_type): - return data_source - return None - @classmethod def from_data_source( cls, @@ -392,7 +381,7 @@ def from_folders( num_workers: Optional[int] = None, **kwargs: Any, ) -> 'DataModule': - data_source = cls.data_source_of_type(FoldersDataSource)( + data_source = (preprocess or cls.preprocess_cls).data_source_of_type(FoldersDataSource)( train_folder=train_folder, val_folder=val_folder, test_folder=test_folder, @@ -433,7 +422,7 @@ def from_files( num_workers: Optional[int] = None, **kwargs: Any, ) -> 'DataModule': - data_source = cls.data_source_of_type(FilesDataSource)( + data_source = (preprocess or cls.preprocess_cls).data_source_of_type(FilesDataSource)( train_files=train_files, train_targets=train_targets, val_files=val_files, diff --git a/flash/data/process.py b/flash/data/process.py index 813717058a..740dcc2f4c 100644 --- a/flash/data/process.py +++ b/flash/data/process.py @@ -29,6 +29,7 @@ if TYPE_CHECKING: from flash.data.data_pipeline import DataPipelineState + from flash.data.data_source import DataSource @dataclass(unsafe_hash=True, frozen=True) @@ -146,6 +147,9 @@ def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool = False): pass +DATA_SOURCE_TYPE = TypeVar("DATA_SOURCE_TYPE") + + class Preprocess(BasePreprocess, Properties, Module): """ The :class:`~flash.data.process.Preprocess` encapsulates @@ -298,6 +302,8 @@ def load_data(cls, path_to_data: str) -> Iterable: """ + data_sources: Optional[List[Type['DataSource']]] + def __init__( self, train_transform: Optional[Dict[str, Callable]] = None, @@ -506,6 +512,14 @@ def per_batch_transform_on_device(self, batch: Any) -> Any: """ return self.current_transform(batch) + @classmethod + def data_source_of_type(cls, data_source_type: Type[DATA_SOURCE_TYPE]) -> Optional[Type[DATA_SOURCE_TYPE]]: + data_sources = cls.data_sources + for data_source in data_sources: + if issubclass(data_source, data_source_type): + return data_source + return None + class DefaultPreprocess(Preprocess): diff --git a/flash/vision/classification/data.py b/flash/vision/classification/data.py index 3b6f4d9716..546e8bbf0b 100644 --- a/flash/vision/classification/data.py +++ b/flash/vision/classification/data.py @@ -36,6 +36,8 @@ class ImageClassificationPreprocess(Preprocess): + data_sources = [ImageFoldersDataSource, ImageFilesDataSource] + def __init__( self, train_transform: Optional[Union[Dict[str, Callable]]] = None, @@ -51,7 +53,6 @@ def __init__( val_transform=val_transform, test_transform=test_transform, predict_transform=predict_transform, - data_sources=[ImageFoldersDataSource, ImageFilesDataSource], ) def get_state_dict(self) -> Dict[str, Any]: From 08657ea8e97e1dac7102aed5bd5f7b173369a07f Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Tue, 4 May 2021 11:45:04 +0100 Subject: [PATCH 11/78] Remove un-needed code --- flash/core/model.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/flash/core/model.py b/flash/core/model.py index 4d929f5cd1..1ac302c44c 100644 --- a/flash/core/model.py +++ 
b/flash/core/model.py @@ -324,9 +324,6 @@ def build_data_pipeline( if not isinstance(data_source, DataSource): data_source = preprocess.data_source_of_type(data_source.as_type())() - if old_data_source is not None: - data_source._state.update(old_data_source._state) # TODO: This is a hack - data_pipeline = DataPipeline(data_source, preprocess, postprocess, serializer) self._data_pipeline_state = data_pipeline.initialize(self._data_pipeline_state) return data_pipeline From 73be79218672079c9f5e3fdc4687adbfb465bcfb Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Tue, 4 May 2021 11:46:10 +0100 Subject: [PATCH 12/78] Remove sequence data source --- flash/data/data_source.py | 32 -------------------------------- 1 file changed, 32 deletions(-) diff --git a/flash/data/data_source.py b/flash/data/data_source.py index f3227aaa7d..f8d4053b3d 100644 --- a/flash/data/data_source.py +++ b/flash/data/data_source.py @@ -143,38 +143,6 @@ def as_type(self) -> Type[DataSource]: return _data_source_types[self] -class SequenceDataSource(DataSource, ABC): - - def __init__( - self, - train_inputs: Optional[Sequence[Any]] = None, - train_targets: Optional[Sequence[Any]] = None, - val_inputs: Optional[Sequence[Any]] = None, - val_targets: Optional[Sequence[Any]] = None, - test_inputs: Optional[Sequence[Any]] = None, - test_targets: Optional[Sequence[Any]] = None, - predict_inputs: Optional[Sequence[Any]] = None, - labels: Optional[Sequence[str]] = None - ): - super().__init__( - train_data=(train_inputs, train_targets), - val_data=(val_inputs, val_targets), - test_data=(test_inputs, test_targets), - predict_data=(predict_inputs, None), - ) - - self.labels = labels - - if self.labels is not None: - self.set_state(LabelsState(self.labels)) - - def load_data(self, data: Any, dataset: Optional[Any] = None) -> Iterable[Mapping[str, Any]]: - inputs, targets = data - if targets is None: - return [{'input': input} for input in inputs] - return [{'input': input, 'target': target} for input, target in zip(inputs, targets)] - - class FoldersDataSource(DataSource, ABC): def __init__( From 338184003943249a011dbf7055d386a18bb87442 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Tue, 4 May 2021 12:23:40 +0100 Subject: [PATCH 13/78] Simplify data source --- flash/data/data_module.py | 52 +++++++++------ flash/data/data_source.py | 129 ++++++++++++++------------------------ flash/vision/data.py | 40 ++---------- 3 files changed, 84 insertions(+), 137 deletions(-) diff --git a/flash/data/data_module.py b/flash/data/data_module.py index ae2c493400..673da29d6e 100644 --- a/flash/data/data_module.py +++ b/flash/data/data_module.py @@ -53,7 +53,11 @@ class DataModule(pl.LightningDataModule): def __init__( self, - data_source: DataSource, + train_dataset: Optional[Dataset] = None, + val_dataset: Optional[Dataset] = None, + test_dataset: Optional[Dataset] = None, + predict_dataset: Optional[Dataset] = None, + data_source: Optional[DataSource] = None, preprocess: Optional[Preprocess] = None, postprocess: Optional[Postprocess] = None, data_fetcher: Optional[BaseDataFetcher] = None, @@ -73,8 +77,6 @@ def __init__( # TODO: Preprocess can change self.data_fetcher.attach_to_preprocess(self.preprocess) - train_dataset, val_dataset, test_dataset, predict_dataset = data_source.to_datasets() - self._train_ds = train_dataset self._val_ds = val_dataset self._test_ds = test_dataset @@ -335,6 +337,10 @@ def _split_train_val( def from_data_source( cls, data_source: DataSource, + train_data: Any = None, + val_data: Any = None, + test_data: Any = 
None, + predict_data: Any = None, train_transform: Optional[Union[str, Dict]] = 'default', val_transform: Optional[Union[str, Dict]] = 'default', test_transform: Optional[Union[str, Dict]] = 'default', @@ -354,9 +360,20 @@ def from_data_source( **kwargs, ) + train_dataset, val_dataset, test_dataset, predict_dataset = data_source.to_datasets( + train_data, + val_data, + test_data, + predict_data, + ) + return cls( - data_source, - preprocess, + train_dataset, + val_dataset, + test_dataset, + predict_dataset, + data_source=data_source, + preprocess=preprocess, data_fetcher=data_fetcher, val_split=val_split, batch_size=batch_size, @@ -381,15 +398,14 @@ def from_folders( num_workers: Optional[int] = None, **kwargs: Any, ) -> 'DataModule': - data_source = (preprocess or cls.preprocess_cls).data_source_of_type(FoldersDataSource)( - train_folder=train_folder, - val_folder=val_folder, - test_folder=test_folder, - predict_folder=predict_folder, - ) + data_source = (preprocess or cls.preprocess_cls).data_source_of_type(FoldersDataSource)() return cls.from_data_source( data_source, + train_folder, + val_folder, + test_folder, + predict_folder, train_transform=train_transform, val_transform=val_transform, test_transform=test_transform, @@ -422,18 +438,14 @@ def from_files( num_workers: Optional[int] = None, **kwargs: Any, ) -> 'DataModule': - data_source = (preprocess or cls.preprocess_cls).data_source_of_type(FilesDataSource)( - train_files=train_files, - train_targets=train_targets, - val_files=val_files, - val_targets=val_targets, - test_files=test_files, - test_targets=test_targets, - predict_files=predict_files, - ) + data_source = (preprocess or cls.preprocess_cls).data_source_of_type(FilesDataSource)() return cls.from_data_source( data_source, + (train_files, train_targets), + (val_files, val_targets), + (test_files, test_targets), + predict_files, train_transform=train_transform, val_transform=val_transform, test_transform=test_transform, diff --git a/flash/data/data_source.py b/flash/data/data_source.py index f8d4053b3d..5f2a6c29df 100644 --- a/flash/data/data_source.py +++ b/flash/data/data_source.py @@ -16,7 +16,7 @@ from abc import ABC from dataclasses import dataclass from enum import Enum -from typing import Any, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple, Type, Union +from typing import Any, Dict, Generic, Iterable, List, Mapping, Optional, Sequence, Tuple, Type, TypeVar, Union from pytorch_lightning.trainer.states import RunningStage from torch.nn import Module @@ -53,24 +53,13 @@ def __setattr__(self, key, value): object.__setattr__(self, key, value) -class DataSource(Properties, Module, ABC): +DATA_TYPE = TypeVar('DATA_TYPE') - def __init__( - self, - train_data: Optional[Any] = None, - val_data: Optional[Any] = None, - test_data: Optional[Any] = None, - predict_data: Optional[Any] = None, - ): - super().__init__() - self.train_data = train_data - self.val_data = val_data - self.test_data = test_data - self.predict_data = predict_data +class DataSource(Generic[DATA_TYPE], Properties, Module, ABC): def load_data(self, - data: Any, + data: DATA_TYPE, dataset: Optional[Any] = None) -> Union[Sequence[Mapping[str, Any]], Iterable[Mapping[str, Any]]]: """Loads entire data from Dataset. The input ``data`` can be anything, but you need to return a Mapping. 
@@ -88,46 +77,45 @@ def load_sample(self, sample: Mapping[str, Any], dataset: Optional[Any] = None) """Loads single sample from dataset""" return sample - def to_datasets(self) -> Tuple[Optional[BaseAutoDataset], ...]: - train_dataset = self._generate_dataset_if_possible(RunningStage.TRAINING) - val_dataset = self._generate_dataset_if_possible(RunningStage.VALIDATING) - test_dataset = self._generate_dataset_if_possible(RunningStage.TESTING) - predict_dataset = self._generate_dataset_if_possible(RunningStage.PREDICTING) - return train_dataset, val_dataset, test_dataset, predict_dataset - - def _generate_dataset_if_possible( + def to_datasets( self, - running_stage: RunningStage, - ) -> Optional[Union[AutoDataset, IterableAutoDataset]]: - data = getattr(self, f"{_STAGES_PREFIX[running_stage]}_data", None) - if data is not None: - return self.generate_dataset(data, running_stage) + train_data: Optional[DATA_TYPE] = None, + val_data: Optional[DATA_TYPE] = None, + test_data: Optional[DATA_TYPE] = None, + predict_data: Optional[DATA_TYPE] = None, + ) -> Tuple[Optional[BaseAutoDataset], ...]: + train_dataset = self.generate_dataset(train_data, RunningStage.TRAINING) + val_dataset = self.generate_dataset(val_data, RunningStage.VALIDATING) + test_dataset = self.generate_dataset(test_data, RunningStage.TESTING) + predict_dataset = self.generate_dataset(predict_data, RunningStage.PREDICTING) + return train_dataset, val_dataset, test_dataset, predict_dataset def generate_dataset( self, - data, + data: Optional[DATA_TYPE], running_stage: RunningStage, ) -> Optional[Union[AutoDataset, IterableAutoDataset]]: - from flash.data.data_pipeline import DataPipeline - - mock_dataset = MockDataset() - with CurrentRunningStageFuncContext(running_stage, "load_data", self): - load_data = getattr( - self, DataPipeline._resolve_function_hierarchy( - 'load_data', - self, - running_stage, - DataSource, + if data is not None: + from flash.data.data_pipeline import DataPipeline + + mock_dataset = MockDataset() + with CurrentRunningStageFuncContext(running_stage, "load_data", self): + load_data = getattr( + self, DataPipeline._resolve_function_hierarchy( + 'load_data', + self, + running_stage, + DataSource, + ) ) - ) - data = load_data(data, mock_dataset) + data = load_data(data, mock_dataset) - if has_len(data): - dataset = AutoDataset(data, self, running_stage) - else: - dataset = IterableAutoDataset(data, self, running_stage) - dataset.__dict__.update(mock_dataset.metadata) - return dataset + if has_len(data): + dataset = AutoDataset(data, self, running_stage) + else: + dataset = IterableAutoDataset(data, self, running_stage) + dataset.__dict__.update(mock_dataset.metadata) + return dataset class DefaultDataSource(Enum): # TODO: This could be replaced with a data source registry that the user can add to @@ -143,22 +131,10 @@ def as_type(self) -> Type[DataSource]: return _data_source_types[self] -class FoldersDataSource(DataSource, ABC): +class FoldersDataSource(DataSource[str], ABC): - def __init__( - self, - train_folder: Optional[Union[str, pathlib.Path, list]] = None, - val_folder: Optional[Union[str, pathlib.Path, list]] = None, - test_folder: Optional[Union[str, pathlib.Path, list]] = None, - predict_folder: Optional[Union[str, pathlib.Path, list]] = None, - extensions: Optional[Tuple[str, ...]] = None, - ): - super().__init__( - train_data=train_folder, - val_data=val_folder, - test_data=test_folder, - predict_data=predict_folder, - ) + def __init__(self, extensions: Optional[Tuple[str, ...]] = None): + 
super().__init__() self.extensions = extensions @@ -178,7 +154,7 @@ def find_classes(dir: str) -> Tuple[List[str], Dict[str, int]]: class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)} return classes, class_to_idx - def load_data(self, data: Any, dataset: Optional[Any] = None) -> Iterable[Mapping[str, Any]]: + def load_data(self, data: str, dataset: Optional[Any] = None) -> Iterable[Mapping[str, Any]]: classes, class_to_idx = self.find_classes(data) if not classes: files = [os.path.join(data, file) for file in os.listdir(data)] @@ -195,29 +171,18 @@ def load_data(self, data: Any, dataset: Optional[Any] = None) -> Iterable[Mappin return [{'input': input, 'target': target} for input, target in data] -class FilesDataSource(DataSource, ABC): +class FilesDataSource(DataSource[Tuple[Sequence[str], Optional[Sequence[Any]]]], ABC): - def __init__( - self, - train_files: Optional[Union[Sequence[Union[str, pathlib.Path]], Iterable[Union[str, pathlib.Path]]]] = None, - train_targets: Optional[Union[Sequence[Any], Iterable[Any]]] = None, - val_files: Optional[Union[Sequence[Union[str, pathlib.Path]], Iterable[Union[str, pathlib.Path]]]] = None, - val_targets: Optional[Union[Sequence[Any], Iterable[Any]]] = None, - test_files: Optional[Union[Sequence[Union[str, pathlib.Path]], Iterable[Union[str, pathlib.Path]]]] = None, - test_targets: Optional[Union[Sequence[Any], Iterable[Any]]] = None, - predict_files: Optional[Union[Sequence[Union[str, pathlib.Path]], Iterable[Union[str, pathlib.Path]]]] = None, - extensions: Optional[Tuple[str, ...]] = None, - ): - super().__init__( - train_data=(train_files, train_targets), - val_data=(val_files, val_targets), - test_data=(test_files, test_targets), - predict_data=(predict_files, None), - ) + def __init__(self, extensions: Optional[Tuple[str, ...]] = None): + super().__init__() self.extensions = extensions - def load_data(self, data: Any, dataset: Optional[Any] = None) -> Iterable[Mapping[str, Any]]: + def load_data( + self, + data: Tuple[Sequence[str], Optional[Sequence[Any]]], + dataset: Optional[Any] = None, + ) -> Iterable[Mapping[str, Any]]: # TODO: Bring back the code to work out how many classes there are if isinstance(data, tuple): files, targets = data diff --git a/flash/vision/data.py b/flash/vision/data.py index f41c05c6c0..feffc211b4 100644 --- a/flash/vision/data.py +++ b/flash/vision/data.py @@ -4,25 +4,13 @@ from PIL import Image from torchvision.datasets.folder import default_loader, IMG_EXTENSIONS -from flash.data.data_source import FilesDataSource, FoldersDataSource, SequenceDataSource +from flash.data.data_source import FilesDataSource, FoldersDataSource class ImageFoldersDataSource(FoldersDataSource): - def __init__( - self, - train_folder: Optional[Union[str, pathlib.Path, list]] = None, - val_folder: Optional[Union[str, pathlib.Path, list]] = None, - test_folder: Optional[Union[str, pathlib.Path, list]] = None, - predict_folder: Optional[Union[str, pathlib.Path, list]] = None, - ): - super().__init__( - train_folder=train_folder, - val_folder=val_folder, - test_folder=test_folder, - predict_folder=predict_folder, - extensions=IMG_EXTENSIONS, - ) + def __init__(self): + super().__init__(extensions=IMG_EXTENSIONS) def load_sample(self, sample: Mapping[str, Any], dataset: Optional[Any] = None) -> Any: result = {} # TODO: this is required to avoid a memory leak, can we automate this? 
@@ -33,26 +21,8 @@ def load_sample(self, sample: Mapping[str, Any], dataset: Optional[Any] = None) class ImageFilesDataSource(FilesDataSource): - def __init__( - self, - train_files: Optional[Union[Sequence[Union[str, pathlib.Path]], Iterable[Union[str, pathlib.Path]]]] = None, - train_targets: Optional[Union[Sequence[Any], Iterable[Any]]] = None, - val_files: Optional[Union[Sequence[Union[str, pathlib.Path]], Iterable[Union[str, pathlib.Path]]]] = None, - val_targets: Optional[Union[Sequence[Any], Iterable[Any]]] = None, - test_files: Optional[Union[Sequence[Union[str, pathlib.Path]], Iterable[Union[str, pathlib.Path]]]] = None, - test_targets: Optional[Union[Sequence[Any], Iterable[Any]]] = None, - predict_files: Optional[Union[Sequence[Union[str, pathlib.Path]], Iterable[Union[str, pathlib.Path]]]] = None, - ): - super().__init__( # TODO: This feels like it can be simplified - train_files=train_files, - train_targets=train_targets, - val_files=val_files, - val_targets=val_targets, - test_files=test_files, - test_targets=test_targets, - predict_files=predict_files, - extensions=IMG_EXTENSIONS - ) + def __init__(self): + super().__init__(extensions=IMG_EXTENSIONS) def load_sample(self, sample: Mapping[str, Any], dataset: Optional[Any] = None) -> Any: result = {} # TODO: this is required to avoid a memory leak, can we automate this? From e01987d2d519e9f80d6ff6d619e14ea5b5ad0eab Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Tue, 4 May 2021 12:35:48 +0100 Subject: [PATCH 14/78] Fix FilesDataSource --- flash/data/data_module.py | 2 +- flash/data/data_source.py | 21 ++++++++++++++------- 2 files changed, 15 insertions(+), 8 deletions(-) diff --git a/flash/data/data_module.py b/flash/data/data_module.py index 673da29d6e..28ec140cfd 100644 --- a/flash/data/data_module.py +++ b/flash/data/data_module.py @@ -298,7 +298,7 @@ def num_classes(self) -> Optional[int]: ) @property - def data_source(self) -> DataSource: + def data_source(self) -> Optional[DataSource]: return self._data_source @property diff --git a/flash/data/data_source.py b/flash/data/data_source.py index 5f2a6c29df..c001c226d0 100644 --- a/flash/data/data_source.py +++ b/flash/data/data_source.py @@ -16,6 +16,7 @@ from abc import ABC from dataclasses import dataclass from enum import Enum +from inspect import signature from typing import Any, Dict, Generic, Iterable, List, Mapping, Optional, Sequence, Tuple, Type, TypeVar, Union from pytorch_lightning.trainer.states import RunningStage @@ -108,7 +109,11 @@ def generate_dataset( DataSource, ) ) - data = load_data(data, mock_dataset) + parameters = signature(load_data).parameters + if len(parameters) > 1 and "dataset" in parameters: # TODO: This was DATASET_KEY before + data = load_data(data, mock_dataset) + else: + data = load_data(data) if has_len(data): dataset = AutoDataset(data, self, running_stage) @@ -184,10 +189,12 @@ def load_data( dataset: Optional[Any] = None, ) -> Iterable[Mapping[str, Any]]: # TODO: Bring back the code to work out how many classes there are - if isinstance(data, tuple): - files, targets = data - else: - files, targets = data, None # TODO: Sort this out + files, targets = data if not targets: - return [{'input': input} for input in files] - return [{'input': file, 'target': target} for file, target in zip(files, targets)] + return self.predict_load_data(files) + filtered = filter(lambda file, _: has_file_allowed_extension(file, self.extensions), zip(files, targets)) + return [{'input': file, 'target': target} for file, target in filtered] + + def 
predict_load_data(self, data: Sequence[str]): + filtered = filter(lambda file: has_file_allowed_extension(file, self.extensions), data) + return [{'input': input} for input in filtered] From e385dfa7876c24678a68bc43f99c9f4ce8de326a Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Tue, 4 May 2021 12:39:04 +0100 Subject: [PATCH 15/78] Minor fix --- flash/data/data_source.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/flash/data/data_source.py b/flash/data/data_source.py index c001c226d0..45c80c6268 100644 --- a/flash/data/data_source.py +++ b/flash/data/data_source.py @@ -187,7 +187,7 @@ def load_data( self, data: Tuple[Sequence[str], Optional[Sequence[Any]]], dataset: Optional[Any] = None, - ) -> Iterable[Mapping[str, Any]]: + ) -> Sequence[Mapping[str, Any]]: # TODO: Bring back the code to work out how many classes there are files, targets = data if not targets: @@ -195,6 +195,6 @@ def load_data( filtered = filter(lambda file, _: has_file_allowed_extension(file, self.extensions), zip(files, targets)) return [{'input': file, 'target': target} for file, target in filtered] - def predict_load_data(self, data: Sequence[str]): + def predict_load_data(self, data: Sequence[str]) -> Sequence[Mapping[str, Any]]: filtered = filter(lambda file: has_file_allowed_extension(file, self.extensions), data) return [{'input': input} for input in filtered] From dc90754410d1bb4bbd28c3da563f7b8299baf34c Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Tue, 4 May 2021 14:12:36 +0100 Subject: [PATCH 16/78] Add numpy and tesnor data sources --- flash/data/data_module.py | 82 ++++++++++++++++++++++++++++- flash/data/data_source.py | 68 ++++++++++++++++++++---- flash/vision/classification/data.py | 4 +- flash/vision/data.py | 18 ++++++- 4 files changed, 157 insertions(+), 15 deletions(-) diff --git a/flash/data/data_module.py b/flash/data/data_module.py index 28ec140cfd..6a38728ad0 100644 --- a/flash/data/data_module.py +++ b/flash/data/data_module.py @@ -28,7 +28,7 @@ from flash.data.base_viz import BaseVisualization from flash.data.callback import BaseDataFetcher from flash.data.data_pipeline import DataPipeline, DefaultPreprocess, Postprocess, Preprocess -from flash.data.data_source import DataSource, FilesDataSource, FoldersDataSource +from flash.data.data_source import DataSource, FilesDataSource, FoldersDataSource, NumpyDataSource, TensorDataSource from flash.data.splits import SplitDataset from flash.data.utils import _STAGES_PREFIX @@ -456,3 +456,83 @@ def from_files( batch_size=batch_size, num_workers=num_workers, ) + + @classmethod + def from_tensors( + cls, + train_data: Optional[Collection[torch.Tensor]] = None, + train_targets: Optional[Collection[Any]] = None, + val_data: Optional[Collection[torch.Tensor]] = None, + val_targets: Optional[Sequence[Any]] = None, + test_data: Optional[Collection[torch.Tensor]] = None, + test_targets: Optional[Sequence[Any]] = None, + predict_data: Optional[Collection[torch.Tensor]] = None, + train_transform: Optional[Union[str, Dict]] = 'default', + val_transform: Optional[Union[str, Dict]] = 'default', + test_transform: Optional[Union[str, Dict]] = 'default', + predict_transform: Optional[Union[str, Dict]] = 'default', + data_fetcher: BaseDataFetcher = None, + preprocess: Optional[Preprocess] = None, + val_split: Optional[float] = None, + batch_size: int = 4, + num_workers: Optional[int] = None, + **kwargs: Any, + ) -> 'DataModule': + data_source = (preprocess or cls.preprocess_cls).data_source_of_type(TensorDataSource)() + + return 
cls.from_data_source( + data_source, + (train_data, train_targets), + (val_data, val_targets), + (test_data, test_targets), + predict_data, + train_transform=train_transform, + val_transform=val_transform, + test_transform=test_transform, + predict_transform=predict_transform, + data_fetcher=data_fetcher, + preprocess=preprocess, + val_split=val_split, + batch_size=batch_size, + num_workers=num_workers, + ) + + @classmethod + def from_numpy( + cls, + train_data: Optional[Collection[np.ndarray]] = None, + train_targets: Optional[Collection[Any]] = None, + val_data: Optional[Collection[np.ndarray]] = None, + val_targets: Optional[Sequence[Any]] = None, + test_data: Optional[Collection[np.ndarray]] = None, + test_targets: Optional[Sequence[Any]] = None, + predict_data: Optional[Collection[np.ndarray]] = None, + train_transform: Optional[Union[str, Dict]] = 'default', + val_transform: Optional[Union[str, Dict]] = 'default', + test_transform: Optional[Union[str, Dict]] = 'default', + predict_transform: Optional[Union[str, Dict]] = 'default', + data_fetcher: BaseDataFetcher = None, + preprocess: Optional[Preprocess] = None, + val_split: Optional[float] = None, + batch_size: int = 4, + num_workers: Optional[int] = None, + **kwargs: Any, + ) -> 'DataModule': + data_source = (preprocess or cls.preprocess_cls).data_source_of_type(NumpyDataSource)() + + return cls.from_data_source( + data_source, + (train_data, train_targets), + (val_data, val_targets), + (test_data, test_targets), + predict_data, + train_transform=train_transform, + val_transform=val_transform, + test_transform=test_transform, + predict_transform=predict_transform, + data_fetcher=data_fetcher, + preprocess=preprocess, + val_split=val_split, + batch_size=batch_size, + num_workers=num_workers, + ) diff --git a/flash/data/data_source.py b/flash/data/data_source.py index 45c80c6268..2a1a4e48df 100644 --- a/flash/data/data_source.py +++ b/flash/data/data_source.py @@ -19,6 +19,8 @@ from inspect import signature from typing import Any, Dict, Generic, Iterable, List, Mapping, Optional, Sequence, Tuple, Type, TypeVar, Union +import numpy as np +import torch from pytorch_lightning.trainer.states import RunningStage from torch.nn import Module from torchvision.datasets.folder import has_file_allowed_extension, make_dataset @@ -54,7 +56,7 @@ def __setattr__(self, key, value): object.__setattr__(self, key, value) -DATA_TYPE = TypeVar('DATA_TYPE') +DATA_TYPE = TypeVar("DATA_TYPE") class DataSource(Generic[DATA_TYPE], Properties, Module, ABC): @@ -176,11 +178,43 @@ def load_data(self, data: str, dataset: Optional[Any] = None) -> Iterable[Mappin return [{'input': input, 'target': target} for input, target in data] -class FilesDataSource(DataSource[Tuple[Sequence[str], Optional[Sequence[Any]]]], ABC): +SEQUENCE_DATA_TYPE = TypeVar("SEQUENCE_DATA_TYPE") - def __init__(self, extensions: Optional[Tuple[str, ...]] = None): + +class SequenceDataSource( + Generic[SEQUENCE_DATA_TYPE], + DataSource[Tuple[Sequence[SEQUENCE_DATA_TYPE], Optional[Sequence[Any]]]], + ABC, +): + + def __init__(self, labels: Optional[Sequence[str]] = None): super().__init__() + self.labels = labels + + if self.labels is not None: + self.set_state(LabelsState(self.labels)) + + def load_data( + self, + data: Tuple[Sequence[SEQUENCE_DATA_TYPE], Optional[Sequence[Any]]], + dataset: Optional[Any] = None, + ) -> Sequence[Mapping[str, Any]]: + # TODO: Bring back the code to work out how many classes there are + inputs, targets = data + if targets is None: + return 
self.predict_load_data(data) + return [{'input': input, 'target': target} for input, target in zip(inputs, targets)] + + def predict_load_data(self, data: Sequence[SEQUENCE_DATA_TYPE]) -> Sequence[Mapping[str, Any]]: + return [{'input': input} for input in data] + + +class FilesDataSource(SequenceDataSource[str], ABC): + + def __init__(self, extensions: Optional[Tuple[str, ...]] = None, labels: Optional[Sequence[str]] = None): + super().__init__(labels=labels) + self.extensions = extensions def load_data( @@ -188,13 +222,25 @@ def load_data( data: Tuple[Sequence[str], Optional[Sequence[Any]]], dataset: Optional[Any] = None, ) -> Sequence[Mapping[str, Any]]: - # TODO: Bring back the code to work out how many classes there are - files, targets = data - if not targets: - return self.predict_load_data(files) - filtered = filter(lambda file, _: has_file_allowed_extension(file, self.extensions), zip(files, targets)) - return [{'input': file, 'target': target} for file, target in filtered] + return list( + filter( + lambda sample: has_file_allowed_extension(sample["input"], self.extensions), + super().load_data(data, dataset), + ) + ) def predict_load_data(self, data: Sequence[str]) -> Sequence[Mapping[str, Any]]: - filtered = filter(lambda file: has_file_allowed_extension(file, self.extensions), data) - return [{'input': input} for input in filtered] + return list( + filter( + lambda sample: has_file_allowed_extension(sample["input"], self.extensions), + super().predict_load_data(data, dataset), + ) + ) + + +class TensorDataSource(SequenceDataSource[torch.Tensor], ABC): + """""" # TODO: Some docstring here + + +class NumpyDataSource(SequenceDataSource[np.ndarray], ABC): + """""" # TODO: Some docstring here diff --git a/flash/vision/classification/data.py b/flash/vision/classification/data.py index 546e8bbf0b..0852a73ed5 100644 --- a/flash/vision/classification/data.py +++ b/flash/vision/classification/data.py @@ -26,7 +26,7 @@ from flash.data.process import Preprocess from flash.utils.imports import _MATPLOTLIB_AVAILABLE from flash.vision.classification.transforms import default_train_transforms, default_val_transforms -from flash.vision.data import ImageFilesDataSource, ImageFoldersDataSource +from flash.vision.data import ImageFilesDataSource, ImageFoldersDataSource, ImageNumpyDataSource, ImageTensorDataSource if _MATPLOTLIB_AVAILABLE: import matplotlib.pyplot as plt @@ -36,7 +36,7 @@ class ImageClassificationPreprocess(Preprocess): - data_sources = [ImageFoldersDataSource, ImageFilesDataSource] + data_sources = [ImageFoldersDataSource, ImageFilesDataSource, ImageNumpyDataSource, ImageTensorDataSource] def __init__( self, diff --git a/flash/vision/data.py b/flash/vision/data.py index feffc211b4..541a463140 100644 --- a/flash/vision/data.py +++ b/flash/vision/data.py @@ -1,10 +1,12 @@ import pathlib from typing import Any, Iterable, Mapping, Optional, Sequence, Union +import torch from PIL import Image from torchvision.datasets.folder import default_loader, IMG_EXTENSIONS +from torchvision.transforms.functional import to_pil_image -from flash.data.data_source import FilesDataSource, FoldersDataSource +from flash.data.data_source import FilesDataSource, FoldersDataSource, NumpyDataSource, TensorDataSource class ImageFoldersDataSource(FoldersDataSource): @@ -29,3 +31,17 @@ def load_sample(self, sample: Mapping[str, Any], dataset: Optional[Any] = None) result.update(sample) result['input'] = default_loader(sample['input']) return result + + +class ImageTensorDataSource(TensorDataSource): + + def 
load_sample(self, sample: Mapping[str, Any], dataset: Optional[Any] = None) -> Any: + result['input'] = to_pil_image(sample['input']) + return result + + +class ImageNumpyDataSource(NumpyDataSource): + + def load_sample(self, sample: Mapping[str, Any], dataset: Optional[Any] = None) -> Any: + result['input'] = to_pil_image(torch.from_numpy(sample['input'])) + return result From c437043374c14c640fe2ffc38fc651a4c58353db Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Tue, 4 May 2021 18:41:21 +0100 Subject: [PATCH 17/78] Fixes --- flash/data/data_module.py | 2 +- flash/data/data_source.py | 6 +++++- flash/vision/data.py | 14 +++++++------- flash_examples/predict/image_embedder.py | 2 +- 4 files changed, 14 insertions(+), 10 deletions(-) diff --git a/flash/data/data_module.py b/flash/data/data_module.py index 6a38728ad0..30ee87c434 100644 --- a/flash/data/data_module.py +++ b/flash/data/data_module.py @@ -14,7 +14,7 @@ import os import pathlib import platform -from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union +from typing import Any, Callable, Collection, Dict, Iterable, List, Optional, Sequence, Tuple, Union import numpy as np import pytorch_lightning as pl diff --git a/flash/data/data_source.py b/flash/data/data_source.py index 2a1a4e48df..e0e9b280a6 100644 --- a/flash/data/data_source.py +++ b/flash/data/data_source.py @@ -129,11 +129,15 @@ class DefaultDataSource(Enum): # TODO: This could be replaced with a data sourc FOLDERS = "folders" FILES = "files" + NUMPY = "numpy" + TENSOR = "tensor" def as_type(self) -> Type[DataSource]: _data_source_types = { DefaultDataSource.FOLDERS: FoldersDataSource, DefaultDataSource.FILES: FilesDataSource, + DefaultDataSource.NUMPY: NumpyDataSource, + DefaultDataSource.TENSOR: TensorDataSource } return _data_source_types[self] @@ -233,7 +237,7 @@ def predict_load_data(self, data: Sequence[str]) -> Sequence[Mapping[str, Any]]: return list( filter( lambda sample: has_file_allowed_extension(sample["input"], self.extensions), - super().predict_load_data(data, dataset), + super().predict_load_data(data), ) ) diff --git a/flash/vision/data.py b/flash/vision/data.py index 541a463140..b1fd91044f 100644 --- a/flash/vision/data.py +++ b/flash/vision/data.py @@ -1,5 +1,5 @@ import pathlib -from typing import Any, Iterable, Mapping, Optional, Sequence, Union +from typing import Any, Dict, Iterable, Mapping, Optional, Sequence, Union import torch from PIL import Image @@ -35,13 +35,13 @@ def load_sample(self, sample: Mapping[str, Any], dataset: Optional[Any] = None) class ImageTensorDataSource(TensorDataSource): - def load_sample(self, sample: Mapping[str, Any], dataset: Optional[Any] = None) -> Any: - result['input'] = to_pil_image(sample['input']) - return result + def load_sample(self, sample: Dict[str, Any], dataset: Optional[Any] = None) -> Any: + sample['input'] = to_pil_image(sample['input']) + return sample class ImageNumpyDataSource(NumpyDataSource): - def load_sample(self, sample: Mapping[str, Any], dataset: Optional[Any] = None) -> Any: - result['input'] = to_pil_image(torch.from_numpy(sample['input'])) - return result + def load_sample(self, sample: Dict[str, Any], dataset: Optional[Any] = None) -> Any: + sample['input'] = to_pil_image(torch.from_numpy(sample['input'])) + return sample diff --git a/flash_examples/predict/image_embedder.py b/flash_examples/predict/image_embedder.py index 04bb155361..54df44a736 100644 --- a/flash_examples/predict/image_embedder.py +++ b/flash_examples/predict/image_embedder.py @@ -33,7 
+33,7 @@ random_image = torch.randn(1, 3, 244, 244) # 6. Generate an embedding from this random image. -embeddings = embedder.predict(random_image) +embeddings = embedder.predict(random_image, data_source="tensor") # 7. Print embeddings shape print(embeddings[0].shape) From b32ee341611ef03a316c0b47c6ac99506da5ab3e Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Wed, 5 May 2021 09:28:39 +0100 Subject: [PATCH 18/78] Onboard object detection --- flash/core/model.py | 17 +- flash/data/auto_dataset.py | 1 + flash/data/data_source.py | 35 +-- flash/data/process.py | 11 +- flash/data/transforms.py | 2 +- flash/vision/classification/data.py | 8 +- flash/vision/classification/transforms.py | 33 +-- flash/vision/data.py | 33 ++- flash/vision/detection/data.py | 216 +++++++----------- flash/vision/detection/model.py | 6 +- flash/vision/detection/transforms.py | 38 +++ flash_examples/finetuning/object_detection.py | 4 +- .../predict/image_classification.py | 2 +- 13 files changed, 212 insertions(+), 194 deletions(-) create mode 100644 flash/vision/detection/transforms.py diff --git a/flash/core/model.py b/flash/core/model.py index 1ac302c44c..ffb0f70fe9 100644 --- a/flash/core/model.py +++ b/flash/core/model.py @@ -29,7 +29,7 @@ from flash.core.schedulers import _SCHEDULERS_REGISTRY from flash.core.utils import get_callable_dict from flash.data.data_pipeline import DataPipeline, DataPipelineState -from flash.data.data_source import DataSource, DefaultDataSource +from flash.data.data_source import DataSource, DefaultDataSources from flash.data.process import Postprocess, Preprocess, Serializer, SerializerMapping @@ -157,7 +157,7 @@ def test_step(self, batch: Any, batch_idx: int) -> None: def predict( self, x: Any, - data_source: Union[str, DefaultDataSource, DataSource] = DefaultDataSource.FILES, + data_source: Union[str, DataSource] = DefaultDataSources.FILES, data_pipeline: Optional[DataPipeline] = None, ) -> Any: """ @@ -259,7 +259,7 @@ def serializer(self, serializer: Union[Serializer, Mapping[str, Serializer]]): def build_data_pipeline( self, - data_source: Optional[Union[str, DefaultDataSource, DataSource]] = None, + data_source: Optional[Union[str, DataSource]] = None, data_pipeline: Optional[DataPipeline] = None, ) -> Optional[DataPipeline]: """Build a :class:`.DataPipeline` incorporating available @@ -318,11 +318,8 @@ def build_data_pipeline( data_source = data_source or old_data_source - if str(data_source) == data_source: - data_source = DefaultDataSource(data_source) - - if not isinstance(data_source, DataSource): - data_source = preprocess.data_source_of_type(data_source.as_type())() + if isinstance(data_source, str): + data_source = preprocess.data_source_of_name(data_source)() data_pipeline = DataPipeline(data_source, preprocess, postprocess, serializer) self._data_pipeline_state = data_pipeline.initialize(self._data_pipeline_state) @@ -392,12 +389,16 @@ def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None: # https://pytorch.org/docs/stable/notes/serialization.html if self.data_pipeline is not None and 'data_pipeline' not in checkpoint: checkpoint['data_pipeline'] = self.data_pipeline + if self._data_pipeline_state is not None and '_data_pipeline_state' not in checkpoint: + checkpoint['_data_pipeline_state'] = self._data_pipeline_state super().on_save_checkpoint(checkpoint) def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: super().on_load_checkpoint(checkpoint) if 'data_pipeline' in checkpoint: self.data_pipeline = checkpoint['data_pipeline'] + if 
'_data_pipeline_state' in checkpoint: + self._data_pipeline_state = checkpoint['_data_pipeline_state'] @classmethod def available_backbones(cls) -> List[str]: diff --git a/flash/data/auto_dataset.py b/flash/data/auto_dataset.py index 9f2cf8775c..1073e4982d 100644 --- a/flash/data/auto_dataset.py +++ b/flash/data/auto_dataset.py @@ -73,6 +73,7 @@ def running_stage(self, running_stage: RunningStage) -> None: ) def _call_load_sample(self, sample: Any) -> Any: + sample = dict(**sample) if self.load_sample: with self._load_sample_context: parameters = signature(self.load_sample).parameters diff --git a/flash/data/data_source.py b/flash/data/data_source.py index e0e9b280a6..abff14004b 100644 --- a/flash/data/data_source.py +++ b/flash/data/data_source.py @@ -22,6 +22,7 @@ import numpy as np import torch from pytorch_lightning.trainer.states import RunningStage +from pytorch_lightning.utilities.enums import LightningEnum from torch.nn import Module from torchvision.datasets.folder import has_file_allowed_extension, make_dataset @@ -125,21 +126,22 @@ def generate_dataset( return dataset -class DefaultDataSource(Enum): # TODO: This could be replaced with a data source registry that the user can add to +class DefaultDataSources(LightningEnum): FOLDERS = "folders" FILES = "files" NUMPY = "numpy" TENSOR = "tensor" - def as_type(self) -> Type[DataSource]: - _data_source_types = { - DefaultDataSource.FOLDERS: FoldersDataSource, - DefaultDataSource.FILES: FilesDataSource, - DefaultDataSource.NUMPY: NumpyDataSource, - DefaultDataSource.TENSOR: TensorDataSource - } - return _data_source_types[self] + +class DefaultDataKeys(LightningEnum): + + INPUT = "input" + TARGET = "target" + + # TODO: Create a FlashEnum class??? + def __hash__(self) -> int: + return hash(self.value) class FoldersDataSource(DataSource[str], ABC): @@ -170,7 +172,7 @@ def load_data(self, data: str, dataset: Optional[Any] = None) -> Iterable[Mappin if not classes: files = [os.path.join(data, file) for file in os.listdir(data)] return [{ - 'input': file + DefaultDataKeys.INPUT: file } for file in filter( lambda file: has_file_allowed_extension(file, self.extensions), files, @@ -179,7 +181,7 @@ def load_data(self, data: str, dataset: Optional[Any] = None) -> Iterable[Mappin self.set_state(LabelsState(classes)) dataset.num_classes = len(classes) data = make_dataset(data, class_to_idx, extensions=self.extensions) - return [{'input': input, 'target': target} for input, target in data] + return [{DefaultDataKeys.INPUT: input, DefaultDataKeys.TARGET: target} for input, target in data] SEQUENCE_DATA_TYPE = TypeVar("SEQUENCE_DATA_TYPE") @@ -208,10 +210,13 @@ def load_data( inputs, targets = data if targets is None: return self.predict_load_data(data) - return [{'input': input, 'target': target} for input, target in zip(inputs, targets)] + return [{ + DefaultDataKeys.INPUT: input, + DefaultDataKeys.TARGET: target + } for input, target in zip(inputs, targets)] def predict_load_data(self, data: Sequence[SEQUENCE_DATA_TYPE]) -> Sequence[Mapping[str, Any]]: - return [{'input': input} for input in data] + return [{DefaultDataKeys.INPUT: input} for input in data] class FilesDataSource(SequenceDataSource[str], ABC): @@ -228,7 +233,7 @@ def load_data( ) -> Sequence[Mapping[str, Any]]: return list( filter( - lambda sample: has_file_allowed_extension(sample["input"], self.extensions), + lambda sample: has_file_allowed_extension(sample[DefaultDataKeys.INPUT], self.extensions), super().load_data(data, dataset), ) ) @@ -236,7 +241,7 @@ def load_data( def 
predict_load_data(self, data: Sequence[str]) -> Sequence[Mapping[str, Any]]: return list( filter( - lambda sample: has_file_allowed_extension(sample["input"], self.extensions), + lambda sample: has_file_allowed_extension(sample[DefaultDataKeys.INPUT], self.extensions), super().predict_load_data(data), ) ) diff --git a/flash/data/process.py b/flash/data/process.py index 740dcc2f4c..3b29175e6d 100644 --- a/flash/data/process.py +++ b/flash/data/process.py @@ -302,7 +302,7 @@ def load_data(cls, path_to_data: str) -> Iterable: """ - data_sources: Optional[List[Type['DataSource']]] + data_sources: Optional[Dict[str, Type['DataSource']]] def __init__( self, @@ -515,11 +515,18 @@ def per_batch_transform_on_device(self, batch: Any) -> Any: @classmethod def data_source_of_type(cls, data_source_type: Type[DATA_SOURCE_TYPE]) -> Optional[Type[DATA_SOURCE_TYPE]]: data_sources = cls.data_sources - for data_source in data_sources: + for data_source in data_sources.values(): if issubclass(data_source, data_source_type): return data_source return None + @classmethod + def data_source_of_name(cls, data_source_name: str) -> Optional[Type[DATA_SOURCE_TYPE]]: + data_sources = cls.data_sources + if data_source_name in data_sources: + return data_sources[data_source_name] + return None + class DefaultPreprocess(Preprocess): diff --git a/flash/data/transforms.py b/flash/data/transforms.py index eed5b640eb..0a26224791 100644 --- a/flash/data/transforms.py +++ b/flash/data/transforms.py @@ -22,7 +22,7 @@ class ApplyToKeys(nn.Sequential): def __init__(self, keys: Union[str, Sequence[str]], *args): super().__init__(*[convert_to_modules(arg) for arg in args]) - if str(keys) == keys: + if isinstance(keys, str): keys = [keys] self.keys = keys diff --git a/flash/vision/classification/data.py b/flash/vision/classification/data.py index 0852a73ed5..6781fb54ed 100644 --- a/flash/vision/classification/data.py +++ b/flash/vision/classification/data.py @@ -23,6 +23,7 @@ from flash.data.base_viz import BaseVisualization # for viz from flash.data.callback import BaseDataFetcher from flash.data.data_module import DataModule +from flash.data.data_source import DefaultDataSources from flash.data.process import Preprocess from flash.utils.imports import _MATPLOTLIB_AVAILABLE from flash.vision.classification.transforms import default_train_transforms, default_val_transforms @@ -36,7 +37,12 @@ class ImageClassificationPreprocess(Preprocess): - data_sources = [ImageFoldersDataSource, ImageFilesDataSource, ImageNumpyDataSource, ImageTensorDataSource] + data_sources = { + DefaultDataSources.FOLDERS: ImageFoldersDataSource, + DefaultDataSources.FILES: ImageFilesDataSource, + DefaultDataSources.NUMPY: ImageNumpyDataSource, + DefaultDataSources.TENSOR: ImageTensorDataSource, + } def __init__( self, diff --git a/flash/vision/classification/transforms.py b/flash/vision/classification/transforms.py index 90454524b3..3eff2f4c2c 100644 --- a/flash/vision/classification/transforms.py +++ b/flash/vision/classification/transforms.py @@ -19,6 +19,7 @@ from torch import nn from torchvision import transforms as T +from flash.data.data_source import DefaultDataKeys from flash.data.transforms import ApplyToKeys from flash.utils.imports import _KORNIA_AVAILABLE @@ -31,29 +32,29 @@ def default_train_transforms(image_size: Tuple[int, int]) -> Dict[str, Callable] # Better approach as all transforms are applied on tensor directly return { "to_tensor_transform": nn.Sequential( - ApplyToKeys('input', torchvision.transforms.ToTensor()), - ApplyToKeys('target', 
torch.as_tensor), + ApplyToKeys(DefaultDataKeys.INPUT, torchvision.transforms.ToTensor()), + ApplyToKeys(DefaultDataKeys.TARGET, torch.as_tensor), ), "post_tensor_transform": ApplyToKeys( - 'input', + DefaultDataKeys.INPUT, # TODO (Edgar): replace with resize once kornia is fixed K.augmentation.RandomResizedCrop(image_size, scale=(1.0, 1.0), ratio=(1.0, 1.0)), K.augmentation.RandomHorizontalFlip(), ), "per_batch_transform_on_device": ApplyToKeys( - 'input', + DefaultDataKeys.INPUT, K.augmentation.Normalize(torch.tensor([0.485, 0.456, 0.406]), torch.tensor([0.229, 0.224, 0.225])), ) } else: return { - "pre_tensor_transform": ApplyToKeys('input', T.Resize(image_size), T.RandomHorizontalFlip()), + "pre_tensor_transform": ApplyToKeys(DefaultDataKeys.INPUT, T.Resize(image_size), T.RandomHorizontalFlip()), "to_tensor_transform": nn.Sequential( - ApplyToKeys('input', torchvision.transforms.ToTensor()), - ApplyToKeys('target', torch.as_tensor), + ApplyToKeys(DefaultDataKeys.INPUT, torchvision.transforms.ToTensor()), + ApplyToKeys(DefaultDataKeys.TARGET, torch.as_tensor), ), "post_tensor_transform": ApplyToKeys( - 'input', + DefaultDataKeys.INPUT, T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), ), } @@ -64,28 +65,28 @@ def default_val_transforms(image_size: Tuple[int, int]) -> Dict[str, Callable]: # Better approach as all transforms are applied on tensor directly return { "to_tensor_transform": nn.Sequential( - ApplyToKeys('input', torchvision.transforms.ToTensor()), - ApplyToKeys('target', torch.as_tensor), + ApplyToKeys(DefaultDataKeys.INPUT, torchvision.transforms.ToTensor()), + ApplyToKeys(DefaultDataKeys.TARGET, torch.as_tensor), ), "post_tensor_transform": ApplyToKeys( - 'input', + DefaultDataKeys.INPUT, # TODO (Edgar): replace with resize once kornia is fixed K.augmentation.RandomResizedCrop(image_size, scale=(1.0, 1.0), ratio=(1.0, 1.0)), ), "per_batch_transform_on_device": ApplyToKeys( - 'input', + DefaultDataKeys.INPUT, K.augmentation.Normalize(torch.tensor([0.485, 0.456, 0.406]), torch.tensor([0.229, 0.224, 0.225])), ) } else: return { - "pre_tensor_transform": ApplyToKeys('input', T.Resize(image_size)), + "pre_tensor_transform": ApplyToKeys(DefaultDataKeys.INPUT, T.Resize(image_size)), "to_tensor_transform": nn.Sequential( - ApplyToKeys('input', torchvision.transforms.ToTensor()), - ApplyToKeys('target', torch.as_tensor), + ApplyToKeys(DefaultDataKeys.INPUT, torchvision.transforms.ToTensor()), + ApplyToKeys(DefaultDataKeys.TARGET, torch.as_tensor), ), "post_tensor_transform": ApplyToKeys( - 'input', + DefaultDataKeys.INPUT, T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), ), } diff --git a/flash/vision/data.py b/flash/vision/data.py index b1fd91044f..9d32cd603a 100644 --- a/flash/vision/data.py +++ b/flash/vision/data.py @@ -1,12 +1,29 @@ -import pathlib -from typing import Any, Dict, Iterable, Mapping, Optional, Sequence, Union +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from typing import Any, Dict, Mapping, Optional import torch -from PIL import Image from torchvision.datasets.folder import default_loader, IMG_EXTENSIONS from torchvision.transforms.functional import to_pil_image -from flash.data.data_source import FilesDataSource, FoldersDataSource, NumpyDataSource, TensorDataSource +from flash.data.data_source import ( + DefaultDataKeys, + FilesDataSource, + FoldersDataSource, + NumpyDataSource, + TensorDataSource, +) class ImageFoldersDataSource(FoldersDataSource): @@ -17,7 +34,7 @@ def __init__(self): def load_sample(self, sample: Mapping[str, Any], dataset: Optional[Any] = None) -> Any: result = {} # TODO: this is required to avoid a memory leak, can we automate this? result.update(sample) - result['input'] = default_loader(sample['input']) + result[DefaultDataKeys.INPUT] = default_loader(sample[DefaultDataKeys.INPUT]) return result @@ -29,19 +46,19 @@ def __init__(self): def load_sample(self, sample: Mapping[str, Any], dataset: Optional[Any] = None) -> Any: result = {} # TODO: this is required to avoid a memory leak, can we automate this? result.update(sample) - result['input'] = default_loader(sample['input']) + result[DefaultDataKeys.INPUT] = default_loader(sample[DefaultDataKeys.INPUT]) return result class ImageTensorDataSource(TensorDataSource): def load_sample(self, sample: Dict[str, Any], dataset: Optional[Any] = None) -> Any: - sample['input'] = to_pil_image(sample['input']) + sample[DefaultDataKeys.INPUT] = to_pil_image(sample[DefaultDataKeys.INPUT]) return sample class ImageNumpyDataSource(NumpyDataSource): def load_sample(self, sample: Dict[str, Any], dataset: Optional[Any] = None) -> Any: - sample['input'] = to_pil_image(torch.from_numpy(sample['input'])) + sample[DefaultDataKeys.INPUT] = to_pil_image(torch.from_numpy(sample[DefaultDataKeys.INPUT])) return sample diff --git a/flash/vision/detection/data.py b/flash/vision/detection/data.py index 35905d683b..4af7ed3de9 100644 --- a/flash/vision/detection/data.py +++ b/flash/vision/detection/data.py @@ -12,168 +12,109 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import os -from typing import Any, Callable, Dict, List, Optional, Tuple, Type +from typing import Any, Callable, Dict, Optional, Sequence, Tuple -import torch -from PIL import Image -from pytorch_lightning.utilities.exceptions import MisconfigurationException -from torch import Tensor, tensor -from torch._six import container_abcs +from pytorch_lightning.trainer.states import RunningStage from torch.nn import Module -from torch.utils.data._utils.collate import default_collate -from torchvision import transforms as T +from torchvision.datasets.folder import default_loader -from flash.data.auto_dataset import AutoDataset from flash.data.data_module import DataModule -from flash.data.process import DefaultPreprocess, Preprocess -from flash.data.utils import _contains_any_tensor +from flash.data.data_source import DataSource +from flash.data.process import Preprocess from flash.utils.imports import _COCO_AVAILABLE -from flash.vision.utils import pil_loader +from flash.vision.detection.transforms import default_transforms if _COCO_AVAILABLE: from pycocotools.coco import COCO -class CustomCOCODataset(torch.utils.data.Dataset): +class COCODataSource(DataSource[Tuple[str, str]]): - def __init__( - self, - root: str, - ann_file: str, - transforms: Optional[Callable] = None, - loader: Optional[Callable] = pil_loader, - ): - if not _COCO_AVAILABLE: - raise ImportError("Kindly install the COCO API `pycocotools` to use the Dataset") - - self.root = root - self.transforms = transforms - self.coco = COCO(ann_file) - self.ids = list(sorted(self.coco.imgs.keys())) - self.loader = loader - - @property - def num_classes(self) -> int: - categories = self.coco.loadCats(self.coco.getCatIds()) - if not categories: - raise ValueError("No Categories found") - return categories[-1]["id"] + 1 - - def __getitem__(self, index: int) -> Tuple[Any, Any]: - coco = self.coco - img_idx = self.ids[index] - - ann_ids = coco.getAnnIds(imgIds=img_idx) - annotations = coco.loadAnns(ann_ids) - - image_path = coco.loadImgs(img_idx)[0]["file_name"] - img = Image.open(os.path.join(self.root, image_path)) - - boxes = [] - labels = [] - areas = [] - iscrowd = [] - - for obj in annotations: - xmin = obj["bbox"][0] - ymin = obj["bbox"][1] - xmax = xmin + obj["bbox"][2] - ymax = ymin + obj["bbox"][3] - - bbox = [xmin, ymin, xmax, ymax] - keep = (bbox[3] > bbox[1]) & (bbox[2] > bbox[0]) - if keep: - boxes.append(bbox) - labels.append(obj["category_id"]) - areas.append(obj["area"]) - iscrowd.append(obj["iscrowd"]) - - target = dict( - boxes=torch.as_tensor(boxes, dtype=torch.float32), - labels=torch.as_tensor(labels, dtype=torch.int64), - image_id=tensor([img_idx]), - area=torch.as_tensor(areas, dtype=torch.float32), - iscrowd=torch.as_tensor(iscrowd, dtype=torch.int64) - ) + def load_data(self, data: Tuple[str, str], dataset: Optional[Any] = None) -> Sequence[Dict[str, Any]]: + root, ann_file = data - if self.transforms: - img = self.transforms(img) + coco = COCO(ann_file) - return img, target + categories = coco.loadCats(coco.getCatIds()) + if categories: + dataset.num_classes = categories[-1]["id"] + 1 - def __len__(self) -> int: - return len(self.ids) + img_ids = list(sorted(coco.imgs.keys())) + paths = coco.loadImgs(img_ids) + data = [] -def _coco_remove_images_without_annotations(dataset): - # Ref: https://github.com/pytorch/vision/blob/master/references/detection/coco_utils.py + for img_id, path in zip(img_ids, paths): + path = path["file_name"] - def _has_only_empty_bbox(annot: List): - return all(any(o <= 1 for o in 
obj["bbox"][2:]) for obj in annot) + ann_ids = coco.getAnnIds(imgIds=img_id) + annotations = coco.loadAnns(ann_ids) - def _has_valid_annotation(annot: List): - # if it's empty, there is no annotation - if not annot: - return False - # if all boxes have close to zero area, there is no annotation - if _has_only_empty_bbox(annot): - return False - return True + boxes, labels, areas, iscrowd = [], [], [], [] - ids = [] - for ds_idx, img_id in enumerate(dataset.ids): - ann_ids = dataset.coco.getAnnIds(imgIds=img_id, iscrowd=None) - anno = dataset.coco.loadAnns(ann_ids) - if _has_valid_annotation(anno): - ids.append(ds_idx) + # Ref: https://github.com/pytorch/vision/blob/master/references/detection/coco_utils.py + if self.training and all(any(o <= 1 for o in obj["bbox"][2:]) for obj in annotations): + continue - dataset = torch.utils.data.Subset(dataset, ids) - return dataset + for obj in annotations: + xmin = obj["bbox"][0] + ymin = obj["bbox"][1] + xmax = xmin + obj["bbox"][2] + ymax = ymin + obj["bbox"][3] + bbox = [xmin, ymin, xmax, ymax] + keep = (bbox[3] > bbox[1]) & (bbox[2] > bbox[0]) + if keep: + boxes.append(bbox) + labels.append(obj["category_id"]) + areas.append(obj["area"]) + iscrowd.append(obj["iscrowd"]) -class ObjectDetectionPreprocess(DefaultPreprocess): + data.append( + dict( + input=os.path.join(root, path), + target=dict( + boxes=boxes, + labels=labels, + image_id=img_id, + area=areas, + iscrowd=iscrowd, + ) + ) + ) + return data - to_tensor = T.ToTensor() + def load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]: + sample['input'] = default_loader(sample['input']) + return sample - def load_data(self, metadata: Any, dataset: AutoDataset) -> CustomCOCODataset: - # Extract folder, coco annotation file and the transform to be applied on the images - folder, ann_file, transform = metadata - ds = CustomCOCODataset(folder, ann_file, transform) - if self.training: - dataset.num_classes = ds.num_classes - ds = _coco_remove_images_without_annotations(ds) - return ds - def predict_load_data(self, samples): - return samples +class ObjectDetectionPreprocess(Preprocess): - def pre_tensor_transform(self, samples: Any) -> Any: - if _contains_any_tensor(samples): - return samples + data_sources = { + "coco": COCODataSource, + } - if isinstance(samples, str): - samples = [samples] + def collate(self, samples: Any) -> Any: + return {key: [sample[key] for sample in samples] for key in samples[0]} - if isinstance(samples, (list, tuple)) and all(isinstance(p, str) for p in samples): - outputs = [] - for sample in samples: - outputs.append(pil_loader(sample)) - return outputs - raise MisconfigurationException("The samples should either be a tensor, a list of paths or a path.") + def get_state_dict(self) -> Dict[str, Any]: + return { + "train_transform": self._train_transform, + "val_transform": self._val_transform, + "test_transform": self._test_transform, + "predict_transform": self._predict_transform, + } - def to_tensor_transform(self, sample) -> Any: - return self.to_tensor(sample[0]), sample[1] + @classmethod + def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool = False): + return cls(**state_dict) - def predict_to_tensor_transform(self, sample) -> Any: - return self.to_tensor(sample[0]) + def default_train_transforms(self) -> Optional[Dict[str, Callable]]: + return default_transforms() - def collate(self, samples: Any) -> Any: - if not isinstance(samples, Tensor): - elem = samples[0] - if isinstance(elem, container_abcs.Sequence): - return tuple(zip(*samples)) - return 
default_collate(samples) - return samples.unsqueeze(dim=0) + def default_val_transforms(self) -> Optional[Dict[str, Callable]]: + return default_transforms() class ObjectDetectionData(DataModule): @@ -192,7 +133,6 @@ def from_coco( test_folder: Optional[str] = None, test_ann_file: Optional[str] = None, test_transform: Optional[Dict[str, Module]] = None, - predict_transform: Optional[Dict[str, Module]] = None, batch_size: int = 4, num_workers: Optional[int] = None, preprocess: Preprocess = None, @@ -202,13 +142,15 @@ def from_coco( train_transform, val_transform, test_transform, - predict_transform, ) - return cls.from_load_data_inputs( - train_load_data_input=(train_folder, train_ann_file, train_transform), - val_load_data_input=(val_folder, val_ann_file, val_transform) if val_folder else None, - test_load_data_input=(test_folder, test_ann_file, test_transform) if test_folder else None, + data_source = preprocess.data_source_of_type(COCODataSource)() + + return cls.from_data_source( + data_source=data_source, + train_data=(train_folder, train_ann_file) if train_folder else None, + val_data=(val_folder, val_ann_file) if val_folder else None, + test_data=(test_folder, test_ann_file) if test_folder else None, batch_size=batch_size, num_workers=num_workers, preprocess=preprocess, diff --git a/flash/vision/detection/model.py b/flash/vision/detection/model.py index a7eed0e105..9204c094b2 100644 --- a/flash/vision/detection/model.py +++ b/flash/vision/detection/model.py @@ -156,7 +156,7 @@ def get_model( def training_step(self, batch, batch_idx) -> Any: """The training step. Overrides ``Task.training_step`` """ - images, targets = batch + images, targets = batch['input'], batch['target'] targets = [{k: v for k, v in t.items()} for t in targets] # fasterrcnn takes both images and targets for training, returns loss_dict @@ -166,7 +166,7 @@ def training_step(self, batch, batch_idx) -> Any: return loss def validation_step(self, batch, batch_idx): - images, targets = batch + images, targets = batch['input'], batch['target'] # fasterrcnn takes only images for eval() mode outs = self.model(images) iou = torch.stack([_evaluate_iou(t, o) for t, o in zip(targets, outs)]).mean() @@ -178,7 +178,7 @@ def validation_epoch_end(self, outs): return {"avg_val_iou": avg_iou, "log": logs} def test_step(self, batch, batch_idx): - images, targets = batch + images, targets = batch['input'], batch['target'] # fasterrcnn takes only images for eval() mode outs = self.model(images) iou = torch.stack([_evaluate_iou(t, o) for t, o in zip(targets, outs)]).mean() diff --git a/flash/vision/detection/transforms.py b/flash/vision/detection/transforms.py new file mode 100644 index 0000000000..735f9db305 --- /dev/null +++ b/flash/vision/detection/transforms.py @@ -0,0 +1,38 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from typing import Callable, Dict + +import torch +import torchvision +from torch import nn + +from flash.data.transforms import ApplyToKeys + + +def default_transforms() -> Dict[str, Callable]: + return { + "to_tensor_transform": nn.Sequential( + ApplyToKeys('input', torchvision.transforms.ToTensor()), + ApplyToKeys( + 'target', + nn.Sequential( + ApplyToKeys('boxes', torch.as_tensor), + ApplyToKeys('labels', torch.as_tensor), + ApplyToKeys('image_id', torch.as_tensor), + ApplyToKeys('area', torch.as_tensor), + ApplyToKeys('iscrowd', torch.as_tensor), + ) + ), + ), + } diff --git a/flash_examples/finetuning/object_detection.py b/flash_examples/finetuning/object_detection.py index 4d013c37ac..eee289cd0d 100644 --- a/flash_examples/finetuning/object_detection.py +++ b/flash_examples/finetuning/object_detection.py @@ -23,14 +23,14 @@ datamodule = ObjectDetectionData.from_coco( train_folder="data/coco128/images/train2017/", train_ann_file="data/coco128/annotations/instances_train2017.json", - batch_size=2 + batch_size=2, ) # 3. Build the model model = ObjectDetector(num_classes=datamodule.num_classes) # 4. Create the trainer -trainer = flash.Trainer(max_epochs=3) +trainer = flash.Trainer(max_epochs=3, limit_train_batches=1) # 5. Finetune the model trainer.finetune(model, datamodule) diff --git a/flash_examples/predict/image_classification.py b/flash_examples/predict/image_classification.py index fe697b2963..3f3f90fef7 100644 --- a/flash_examples/predict/image_classification.py +++ b/flash_examples/predict/image_classification.py @@ -19,7 +19,7 @@ download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", "data/") # 2. Load the model from a checkpoint -model = ImageClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/image_classification_model.pt") +model = ImageClassifier.load_from_checkpoint("../finetuning/image_classification_model.pt") # 3a. Predict what's on a few images! ants or bees? predictions = model.predict([ From bfd320d850723cb7e37b1b453563df43b928bd59 Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 5 May 2021 13:23:17 +0100 Subject: [PATCH 19/78] update --- flash/core/classification.py | 7 ++----- flash/core/finetuning.py | 1 + flash/data/auto_dataset.py | 2 +- flash/data/data_pipeline.py | 8 +++++++- flash/tabular/classification/data/data.py | 4 ++-- flash/text/classification/data.py | 4 ++-- tests/core/test_finetuning.py | 2 +- tests/data/test_process.py | 4 ++-- 8 files changed, 18 insertions(+), 14 deletions(-) diff --git a/flash/core/classification.py b/flash/core/classification.py index a716ae21e2..fb7b45a2f9 100644 --- a/flash/core/classification.py +++ b/flash/core/classification.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from dataclasses import dataclass from typing import Any, Callable, List, Mapping, Optional, Sequence, Union import torch @@ -125,7 +124,7 @@ class Labels(Classes): Args: labels: A list of labels, assumed to map the class index to the label for that class. If ``labels`` is not - provided, will attempt to get them from the :class:`.ClassificationState`. + provided, will attempt to get them from the :class:`.LabelsState`. multi_label: If true, treats outputs as multi label logits. 
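For reference, a minimal usage sketch of the ``Labels`` serializer whose state handling the hunks above and below update; the label names and logits here are hypothetical, and the final line assumes the default single-label ``Classes`` behaviour of taking an argmax:

    import torch
    from flash.core.classification import Labels

    # Explicit labels take priority; without them the serializer falls back to any
    # shared LabelsState, and otherwise warns and returns plain class indices.
    serializer = Labels(labels=["negative", "positive"])
    serializer.serialize(torch.tensor([0.1, 2.3]))  # -> "positive" (argmax of the logits)
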
@@ -155,7 +154,5 @@ def serialize(self, sample: Any) -> Union[int, List[int], str, List[str]]: return [labels[cls] for cls in classes] return labels[classes] else: - rank_zero_warn( - "No ClassificationState was found, this serializer will act as a Classes serializer.", UserWarning - ) + rank_zero_warn("No LabelsState was found, this serializer will act as a Classes serializer.", UserWarning) return classes diff --git a/flash/core/finetuning.py b/flash/core/finetuning.py index 9acb79d3be..63eb209a00 100644 --- a/flash/core/finetuning.py +++ b/flash/core/finetuning.py @@ -50,6 +50,7 @@ def __init__(self, attr_names: Union[str, List[str]] = "backbone", train_bn: boo train_bn: Whether to train Batch Norm layer """ + super().__init__() self.attr_names = [attr_names] if isinstance(attr_names, str) else attr_names self.train_bn = train_bn diff --git a/flash/data/auto_dataset.py b/flash/data/auto_dataset.py index 1073e4982d..1c468dfb37 100644 --- a/flash/data/auto_dataset.py +++ b/flash/data/auto_dataset.py @@ -55,7 +55,7 @@ def running_stage(self) -> RunningStage: @running_stage.setter def running_stage(self, running_stage: RunningStage) -> None: - from flash.data.data_pipeline import DataPipeline + from flash.data.data_pipeline import DataPipeline # noqa F811 from flash.data.data_source import DataSource # Hack to avoid circular import TODO: something better than this self._running_stage = running_stage diff --git a/flash/data/data_pipeline.py b/flash/data/data_pipeline.py index 0f8214d6c1..edb12265f1 100644 --- a/flash/data/data_pipeline.py +++ b/flash/data/data_pipeline.py @@ -543,7 +543,13 @@ def __init__(self, func_to_wrap: Callable, model: 'Task') -> None: def __call__(self, *args, **kwargs): outputs = self.func(*args, **kwargs) - internal_running_state = self.internal_mapping[self.model.trainer._running_stage] + # todo (tchaton) Remove this check + try: + stage = self.model.trainer._running_stage + except AttributeError: + stage = self.model.trainer.state.stage + + internal_running_state = self.internal_mapping[stage] additional_func = self._stage_mapping.get(internal_running_state, None) if additional_func: diff --git a/flash/tabular/classification/data/data.py b/flash/tabular/classification/data/data.py index ec40abc82d..87e7751bff 100644 --- a/flash/tabular/classification/data/data.py +++ b/flash/tabular/classification/data/data.py @@ -20,7 +20,7 @@ from sklearn.model_selection import train_test_split from torch.utils.data import Dataset -from flash.core.classification import ClassificationState +from flash.core.classification import LabelsState from flash.data.auto_dataset import AutoDataset from flash.data.data_module import DataModule from flash.data.process import Preprocess @@ -49,7 +49,7 @@ def __init__( is_regression: bool, ): super().__init__() - self.set_state(ClassificationState(classes)) + self.set_state(LabelsState(classes)) self.cat_cols = cat_cols self.num_cols = num_cols diff --git a/flash/text/classification/data.py b/flash/text/classification/data.py index d836cb5552..3c65dfcc0f 100644 --- a/flash/text/classification/data.py +++ b/flash/text/classification/data.py @@ -21,7 +21,7 @@ from transformers import AutoTokenizer, default_data_collator from transformers.modeling_outputs import SequenceClassifierOutput -from flash.core.classification import ClassificationState +from flash.core.classification import LabelsState from flash.data.auto_dataset import AutoDataset from flash.data.data_module import DataModule from flash.data.process import Postprocess, Preprocess @@ 
-81,7 +81,7 @@ def __init__( class_to_label_mapping = ['CLASS_UNKNOWN'] * (max(self.label_to_class_mapping.values()) + 1) for label, cls in self.label_to_class_mapping.items(): class_to_label_mapping[cls] = label - self.set_state(ClassificationState(class_to_label_mapping)) + self.set_state(LabelsState(class_to_label_mapping)) def get_state_dict(self) -> Dict[str, Any]: return { diff --git a/tests/core/test_finetuning.py b/tests/core/test_finetuning.py index df12d85ca2..e86838b558 100644 --- a/tests/core/test_finetuning.py +++ b/tests/core/test_finetuning.py @@ -25,7 +25,7 @@ class DummyDataset(torch.utils.data.Dataset): def __getitem__(self, index: int) -> Any: - return torch.rand(3, 64, 64), torch.randint(10, size=(1, )).item() + return {"input": torch.rand(3, 64, 64), "target": torch.randint(10, size=(1, )).item()} def __len__(self) -> int: return 100 diff --git a/tests/data/test_process.py b/tests/data/test_process.py index 6dcd9e8f97..efbbd82d2c 100644 --- a/tests/data/test_process.py +++ b/tests/data/test_process.py @@ -19,7 +19,7 @@ from torch.utils.data import DataLoader from flash import Task, Trainer -from flash.core.classification import ClassificationState, Labels +from flash.core.classification import Labels, LabelsState from flash.data.data_pipeline import DataPipeline, DataPipelineState, DefaultPreprocess from flash.data.process import ProcessState, Properties, Serializer, SerializerMapping @@ -129,4 +129,4 @@ def __init__(self): trainer.save_checkpoint(checkpoint_file) model = CustomModel.load_from_checkpoint(checkpoint_file) assert isinstance(model.preprocess._data_pipeline_state, DataPipelineState) - assert model.preprocess._data_pipeline_state._state[ClassificationState] == ClassificationState(['a', 'b']) + assert model.preprocess._data_pipeline_state._state[LabelsState] == LabelsState(['a', 'b']) From 7e050becddabff648750ff8d37ce0eb6ed162b34 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Wed, 5 May 2021 14:12:42 +0100 Subject: [PATCH 20/78] Add text classification --- flash/core/model.py | 4 +- flash/data/data_module.py | 77 ++-- flash/data/data_source.py | 25 +- flash/data/process.py | 2 + flash/text/classification/data.py | 335 ++++++------------ flash/vision/classification/model.py | 17 + .../finetuning/text_classification.py | 6 +- flash_examples/predict/text_classification.py | 14 +- 8 files changed, 218 insertions(+), 262 deletions(-) diff --git a/flash/core/model.py b/flash/core/model.py index ffb0f70fe9..3852148395 100644 --- a/flash/core/model.py +++ b/flash/core/model.py @@ -112,8 +112,7 @@ def step(self, batch: Any, batch_idx: int) -> Any: """ The training/validation/test step. Override for custom behavior. 
""" - x, y = batch['input'], batch['target'] - # x, y = batch + x, y = batch y_hat = self(x) output = {"y_hat": y_hat} losses = {name: l_fn(y_hat, y) for name, l_fn in self.loss_fn.items()} @@ -185,7 +184,6 @@ def predict( return predictions def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: - batch = batch['input'] if isinstance(batch, tuple): batch = batch[0] elif isinstance(batch, list): diff --git a/flash/data/data_module.py b/flash/data/data_module.py index 30ee87c434..13a43b6301 100644 --- a/flash/data/data_module.py +++ b/flash/data/data_module.py @@ -28,7 +28,7 @@ from flash.data.base_viz import BaseVisualization from flash.data.callback import BaseDataFetcher from flash.data.data_pipeline import DataPipeline, DefaultPreprocess, Postprocess, Preprocess -from flash.data.data_source import DataSource, FilesDataSource, FoldersDataSource, NumpyDataSource, TensorDataSource +from flash.data.data_source import DataSource, DefaultDataSources from flash.data.splits import SplitDataset from flash.data.utils import _STAGES_PREFIX @@ -336,7 +336,7 @@ def _split_train_val( @classmethod def from_data_source( cls, - data_source: DataSource, + data_source: str, train_data: Any = None, val_data: Any = None, test_data: Any = None, @@ -350,16 +350,19 @@ def from_data_source( val_split: Optional[float] = None, batch_size: int = 4, num_workers: Optional[int] = None, - **kwargs: Any, + data_source_kwargs: Dict[str, Any] = {}, + preprocess_kwargs: Dict[str, Any] = {}, ) -> 'DataModule': preprocess = preprocess or cls.preprocess_cls( train_transform, val_transform, test_transform, predict_transform, - **kwargs, + **preprocess_kwargs, ) + data_source = preprocess.data_source_of_name(data_source)(**data_source_kwargs) + train_dataset, val_dataset, test_dataset, predict_dataset = data_source.to_datasets( train_data, val_data, @@ -396,12 +399,11 @@ def from_folders( val_split: Optional[float] = None, batch_size: int = 4, num_workers: Optional[int] = None, - **kwargs: Any, + data_source_kwargs: Dict[str, Any] = {}, + preprocess_kwargs: Dict[str, Any] = {}, ) -> 'DataModule': - data_source = (preprocess or cls.preprocess_cls).data_source_of_type(FoldersDataSource)() - return cls.from_data_source( - data_source, + DefaultDataSources.FOLDERS, train_folder, val_folder, test_folder, @@ -415,6 +417,8 @@ def from_folders( val_split=val_split, batch_size=batch_size, num_workers=num_workers, + data_source_kwargs=data_source_kwargs, + preprocess_kwargs=preprocess_kwargs, ) @classmethod @@ -436,12 +440,11 @@ def from_files( val_split: Optional[float] = None, batch_size: int = 4, num_workers: Optional[int] = None, - **kwargs: Any, + data_source_kwargs: Dict[str, Any] = {}, + preprocess_kwargs: Dict[str, Any] = {}, ) -> 'DataModule': - data_source = (preprocess or cls.preprocess_cls).data_source_of_type(FilesDataSource)() - return cls.from_data_source( - data_source, + DefaultDataSources.FILES, (train_files, train_targets), (val_files, val_targets), (test_files, test_targets), @@ -455,6 +458,8 @@ def from_files( val_split=val_split, batch_size=batch_size, num_workers=num_workers, + data_source_kwargs=data_source_kwargs, + preprocess_kwargs=preprocess_kwargs, ) @classmethod @@ -476,12 +481,11 @@ def from_tensors( val_split: Optional[float] = None, batch_size: int = 4, num_workers: Optional[int] = None, - **kwargs: Any, + data_source_kwargs: Dict[str, Any] = {}, + preprocess_kwargs: Dict[str, Any] = {}, ) -> 'DataModule': - data_source = (preprocess or 
cls.preprocess_cls).data_source_of_type(TensorDataSource)() - return cls.from_data_source( - data_source, + DefaultDataSources.TENSOR, (train_data, train_targets), (val_data, val_targets), (test_data, test_targets), @@ -495,6 +499,8 @@ def from_tensors( val_split=val_split, batch_size=batch_size, num_workers=num_workers, + data_source_kwargs=data_source_kwargs, + preprocess_kwargs=preprocess_kwargs, ) @classmethod @@ -516,12 +522,11 @@ def from_numpy( val_split: Optional[float] = None, batch_size: int = 4, num_workers: Optional[int] = None, - **kwargs: Any, + data_source_kwargs: Dict[str, Any] = {}, + preprocess_kwargs: Dict[str, Any] = {}, ) -> 'DataModule': - data_source = (preprocess or cls.preprocess_cls).data_source_of_type(NumpyDataSource)() - return cls.from_data_source( - data_source, + DefaultDataSources.NUMPY, (train_data, train_targets), (val_data, val_targets), (test_data, test_targets), @@ -535,4 +540,36 @@ def from_numpy( val_split=val_split, batch_size=batch_size, num_workers=num_workers, + data_source_kwargs=data_source_kwargs, + preprocess_kwargs=preprocess_kwargs, + ) + + @classmethod + def from_csv( + cls, + input_fields: Union[str, List[str]], + target_fields: Optional[Union[str, List[str]]] = None, + train_file: Optional[str] = None, + val_file: Optional[str] = None, + test_file: Optional[str] = None, + predict_file: Optional[str] = None, + data_fetcher: BaseDataFetcher = None, + preprocess: Optional[Preprocess] = None, + val_split: Optional[float] = None, + batch_size: int = 4, + num_workers: Optional[int] = None, + data_source_kwargs: Dict[str, Any] = {}, + preprocess_kwargs: Dict[str, Any] = {}, + ) -> 'DataModule': + return cls.from_data_source( + DefaultDataSources.CSV, + (train_file, input_fields, target_fields), + (val_file, input_fields, target_fields), + (test_file, input_fields, target_fields), + (predict_file, input_fields, target_fields), + data_fetcher=data_fetcher, + preprocess=preprocess, + val_split=val_split, + batch_size=batch_size, + num_workers=num_workers, ) diff --git a/flash/data/data_source.py b/flash/data/data_source.py index abff14004b..8e357e3a96 100644 --- a/flash/data/data_source.py +++ b/flash/data/data_source.py @@ -12,12 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import os -import pathlib -from abc import ABC from dataclasses import dataclass -from enum import Enum from inspect import signature -from typing import Any, Dict, Generic, Iterable, List, Mapping, Optional, Sequence, Tuple, Type, TypeVar, Union +from typing import Any, Dict, Generic, Iterable, List, Mapping, Optional, Sequence, Tuple, TypeVar, Union import numpy as np import torch @@ -28,7 +25,7 @@ from flash.data.auto_dataset import AutoDataset, BaseAutoDataset, IterableAutoDataset from flash.data.process import ProcessState, Properties -from flash.data.utils import _STAGES_PREFIX, CurrentRunningStageFuncContext +from flash.data.utils import CurrentRunningStageFuncContext def has_len(data: Union[Sequence[Any], Iterable[Any]]) -> bool: @@ -60,7 +57,7 @@ def __setattr__(self, key, value): DATA_TYPE = TypeVar("DATA_TYPE") -class DataSource(Generic[DATA_TYPE], Properties, Module, ABC): +class DataSource(Generic[DATA_TYPE], Properties, Module): def load_data(self, data: DATA_TYPE, @@ -99,7 +96,11 @@ def generate_dataset( data: Optional[DATA_TYPE], running_stage: RunningStage, ) -> Optional[Union[AutoDataset, IterableAutoDataset]]: - if data is not None: + is_none = data is None + if isinstance(data, Sequence): + is_none = data[0] is None + + if not is_none: from flash.data.data_pipeline import DataPipeline mock_dataset = MockDataset() @@ -132,6 +133,7 @@ class DefaultDataSources(LightningEnum): FILES = "files" NUMPY = "numpy" TENSOR = "tensor" + CSV = "csv" class DefaultDataKeys(LightningEnum): @@ -144,7 +146,7 @@ def __hash__(self) -> int: return hash(self.value) -class FoldersDataSource(DataSource[str], ABC): +class FoldersDataSource(DataSource[str]): def __init__(self, extensions: Optional[Tuple[str, ...]] = None): super().__init__() @@ -190,7 +192,6 @@ def load_data(self, data: str, dataset: Optional[Any] = None) -> Iterable[Mappin class SequenceDataSource( Generic[SEQUENCE_DATA_TYPE], DataSource[Tuple[Sequence[SEQUENCE_DATA_TYPE], Optional[Sequence[Any]]]], - ABC, ): def __init__(self, labels: Optional[Sequence[str]] = None): @@ -219,7 +220,7 @@ def predict_load_data(self, data: Sequence[SEQUENCE_DATA_TYPE]) -> Sequence[Mapp return [{DefaultDataKeys.INPUT: input} for input in data] -class FilesDataSource(SequenceDataSource[str], ABC): +class FilesDataSource(SequenceDataSource[str]): def __init__(self, extensions: Optional[Tuple[str, ...]] = None, labels: Optional[Sequence[str]] = None): super().__init__(labels=labels) @@ -247,9 +248,9 @@ def predict_load_data(self, data: Sequence[str]) -> Sequence[Mapping[str, Any]]: ) -class TensorDataSource(SequenceDataSource[torch.Tensor], ABC): +class TensorDataSource(SequenceDataSource[torch.Tensor]): """""" # TODO: Some docstring here -class NumpyDataSource(SequenceDataSource[np.ndarray], ABC): +class NumpyDataSource(SequenceDataSource[np.ndarray]): """""" # TODO: Some docstring here diff --git a/flash/data/process.py b/flash/data/process.py index 3b29175e6d..f6ec29c55a 100644 --- a/flash/data/process.py +++ b/flash/data/process.py @@ -54,6 +54,8 @@ def __init__(self): self._state: Dict[Type[ProcessState], ProcessState] = {} def get_state(self, state_type: Type[STATE_TYPE]) -> Optional[STATE_TYPE]: + if state_type in self._state: + return self._state[state_type] if self._data_pipeline_state is not None: return self._data_pipeline_state.get_state(state_type) else: diff --git a/flash/text/classification/data.py b/flash/text/classification/data.py index d836cb5552..57fcc0a582 100644 --- a/flash/text/classification/data.py +++ 
b/flash/text/classification/data.py @@ -13,7 +13,7 @@ # limitations under the License. import os from functools import partial -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union from datasets import DatasetDict, load_dataset from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -21,130 +21,58 @@ from transformers import AutoTokenizer, default_data_collator from transformers.modeling_outputs import SequenceClassifierOutput -from flash.core.classification import ClassificationState from flash.data.auto_dataset import AutoDataset +from flash.data.callback import BaseDataFetcher from flash.data.data_module import DataModule +from flash.data.data_source import DataSource, DefaultDataKeys, DefaultDataSources, LabelsState from flash.data.process import Postprocess, Preprocess -class TextClassificationPreprocess(Preprocess): - - def __init__( - self, - input: str, - backbone: str, - max_length: int, - target: str, - filetype: str, - train_file: Optional[str], - label_to_class_mapping: Optional[Dict[str, int]], - ): - """ - This class contains the preprocessing logic for text classification - - Args: - # tokenizer: Hugging Face Tokenizer. # TODO: Add back a tokenizer argument and make backbone optional? - input: The field storing the text to be classified. - max_length: Maximum number of tokens within a single sentence. - target: The field storing the class id of the associated text. - filetype: .csv or .json format type. - label_to_class_mapping: Dictionary mapping target labels to class indexes. - """ +class TextDataSource(DataSource): + def __init__(self, backbone: str = "prajjwal1/bert-tiny", max_length: int = 128): super().__init__() - if label_to_class_mapping is None: - if train_file is not None: - label_to_class_mapping = self.get_label_to_class_mapping(train_file, target, filetype) - else: - raise MisconfigurationException( - "Either ``label_to_class_mapping`` or ``train_file`` needs to be provided" - ) - self.backbone = backbone - self.tokenizer = AutoTokenizer.from_pretrained(backbone, use_fast=True) - self.input = input - self.filetype = filetype self.max_length = max_length - self.label_to_class_mapping = label_to_class_mapping - self.target = target - - self._tokenize_fn = partial( - self._tokenize_fn, - tokenizer=self.tokenizer, - input=self.input, - max_length=self.max_length, - truncation=True, - padding="max_length" - ) - - class_to_label_mapping = ['CLASS_UNKNOWN'] * (max(self.label_to_class_mapping.values()) + 1) - for label, cls in self.label_to_class_mapping.items(): - class_to_label_mapping[cls] = label - self.set_state(ClassificationState(class_to_label_mapping)) - - def get_state_dict(self) -> Dict[str, Any]: - return { - "input": self.input, - "backbone": self.backbone, - "max_length": self.max_length, - "target": self.target, - "filetype": self.filetype, - "label_to_class_mapping": self.label_to_class_mapping, - } - - @classmethod - def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool): - return cls(**state_dict) - def per_batch_transform(self, batch: Any) -> Any: - if "labels" not in batch: - # todo: understand why an extra dimension has been added. 
- if batch["input_ids"].dim() == 3: - batch["input_ids"] = batch["input_ids"].squeeze(0) - return batch + self.tokenizer = AutoTokenizer.from_pretrained(backbone, use_fast=True) - @staticmethod def _tokenize_fn( + self, ex: Union[Dict[str, str], str], - tokenizer=None, input: str = None, - max_length: int = None, - **kwargs ) -> Callable: """This function is used to tokenize sentences using the provided tokenizer.""" if isinstance(ex, dict): ex = ex[input] - return tokenizer(ex, max_length=max_length, **kwargs) + return self.tokenizer(ex, max_length=self.max_length, truncation=True, padding="max_length") - def collate(self, samples: Any) -> Tensor: - """Override to convert a set of samples to a batch""" - if isinstance(samples, dict): - samples = [samples] - return default_data_collator(samples) - - def _transform_label(self, ex: Dict[str, str]): - ex[self.target] = self.label_to_class_mapping[ex[self.target]] + def _transform_label(self, label_to_class_mapping: Dict[str, int], target: str, ex: Dict[str, Union[int, str]]): + ex[target] = label_to_class_mapping[ex[target]] return ex - @staticmethod - def get_label_to_class_mapping(file: str, target: str, filetype: str) -> Dict[str, int]: - data_files = {'train': file} - dataset_dict = load_dataset(filetype, data_files=data_files) - label_to_class_mapping = {v: k for k, v in enumerate(list(sorted(list(set(dataset_dict['train'][target])))))} - return label_to_class_mapping + +class TextFileDataSource(TextDataSource): + + def __init__(self, filetype: str, backbone: str = "prajjwal1/bert-tiny", max_length: int = 128): + super().__init__(backbone=backbone, max_length=max_length) + + self.filetype = filetype def load_data( self, - filepath: str, - dataset: AutoDataset, + data: Tuple[str, Union[str, List[str]], Union[str, List[str]]], + dataset: Optional[Any] = None, columns: Union[List[str], Tuple[str]] = ("input_ids", "attention_mask", "labels"), - use_full: bool = True - ): + use_full: bool = True, + ) -> Union[Sequence[Mapping[str, Any]]]: + csv_file, input, target = data + data_files = {} - stage = dataset.running_stage.value - data_files[stage] = str(filepath) + stage = self.running_stage.value + data_files[stage] = str(csv_file) # FLASH_TESTING is set in the CI to run faster. if use_full and os.getenv("FLASH_TESTING", "0") == "0": @@ -155,37 +83,88 @@ def load_data( stage: load_dataset(self.filetype, data_files=data_files, split=[f'{stage}[:20]'])[0] }) - dataset_dict = dataset_dict.map(self._tokenize_fn, batched=True) + if self.training: + labels = list(sorted(list(set(dataset_dict[stage][target])))) + dataset.num_classes = len(labels) + self.set_state(LabelsState(labels)) + + labels = self.get_state(LabelsState) # convert labels to ids - if not self.predicting: - dataset_dict = dataset_dict.map(self._transform_label) + # if not self.predicting: + if labels is not None: + labels = labels.labels + label_to_class_mapping = {v: k for k, v in enumerate(labels)} + dataset_dict = dataset_dict.map(partial(self._transform_label, label_to_class_mapping, target)) - dataset_dict = dataset_dict.map(self._tokenize_fn, batched=True) + dataset_dict = dataset_dict.map(partial(self._tokenize_fn, input=input), batched=True) # Hugging Face models expect target to be named ``labels``. 
- if not self.predicting and self.target != "labels": - dataset_dict.rename_column_(self.target, "labels") + if not self.predicting and target != "labels": + dataset_dict.rename_column_(target, "labels") dataset_dict.set_format("torch", columns=columns) - if not self.predicting: - dataset.num_classes = len(self.label_to_class_mapping) - return dataset_dict[stage] - def predict_load_data(self, sample: Any, dataset: AutoDataset): - if isinstance(sample, str) and os.path.isfile(sample) and sample.endswith(".csv"): - return self.load_data(sample, dataset, columns=["input_ids", "attention_mask"]) - else: - if isinstance(sample, str): - sample = [sample] + def predict_load_data(self, data: Any, dataset: AutoDataset): + return self.load_data(data, dataset, columns=["input_ids", "attention_mask"]) + + +class TextCSVDataSource(TextFileDataSource): + + def __init__(self, backbone: str = "prajjwal1/bert-tiny", max_length: int = 128): + super().__init__("csv", backbone=backbone, max_length=max_length) + + +class TextJSONDataSource(TextFileDataSource): + + def __init__(self, backbone: str = "prajjwal1/bert-tiny", max_length: int = 128): + super().__init__("json", backbone=backbone, max_length=max_length) + + +class TextSentencesDataSource(TextDataSource): - if isinstance(sample, list) and all(isinstance(s, str) for s in sample): - return [self._tokenize_fn(s) for s in sample] + def __init__(self, backbone: str = "prajjwal1/bert-tiny", max_length: int = 128): + super().__init__(backbone=backbone, max_length=max_length) - else: - raise MisconfigurationException("Currently, we support only list of sentences") + def load_data( + self, + data: Union[str, List[str]], + dataset: Optional[Any] = None, + ) -> Union[Sequence[Mapping[str, Any]]]: + + if isinstance(data, str): + data = [data] + return [self._tokenize_fn(s, ) for s in data] + + +class TextClassificationPreprocess(Preprocess): + + data_sources = { + DefaultDataSources.CSV: TextCSVDataSource, + "sentences": TextSentencesDataSource, + } + + def get_state_dict(self) -> Dict[str, Any]: + return {} + + @classmethod + def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool): + return cls(**state_dict) + + def per_batch_transform(self, batch: Any) -> Any: + if "labels" not in batch: + # todo: understand why an extra dimension has been added. 
+ if batch["input_ids"].dim() == 3: + batch["input_ids"] = batch["input_ids"].squeeze(0) + return batch + + def collate(self, samples: Any) -> Tensor: + """Override to convert a set of samples to a batch""" + if isinstance(samples, dict): + samples = [samples] + return default_data_collator(samples) class TextClassificationPostProcess(Postprocess): @@ -202,117 +181,37 @@ class TextClassificationData(DataModule): preprocess_cls = TextClassificationPreprocess postprocess_cls = TextClassificationPostProcess - @property - def num_classes(self) -> int: - return len(self._preprocess.label_to_class_mapping) - @classmethod - def from_files( + def from_csv( cls, - train_file: Optional[str], - input: Optional[str] = 'input', - target: Optional[str] = 'labels', - filetype: str = "csv", - backbone: str = "prajjwal1/bert-tiny", + input_fields: Union[str, List[str]], + target_fields: Optional[Union[str, List[str]]] = None, + train_file: Optional[str] = None, val_file: Optional[str] = None, test_file: Optional[str] = None, predict_file: Optional[str] = None, + backbone: str = "prajjwal1/bert-tiny", max_length: int = 128, - label_to_class_mapping: Optional[dict] = None, - batch_size: int = 16, - num_workers: Optional[int] = None, + data_fetcher: BaseDataFetcher = None, preprocess: Optional[Preprocess] = None, - postprocess: Optional[Postprocess] = None, - ) -> 'TextClassificationData': - """Creates a TextClassificationData object from files. - - Args: - train_file: Path to training data. - input: The field storing the text to be classified. - target: The field storing the class id of the associated text. - filetype: .csv or .json - backbone: Tokenizer backbone to use, can use any HuggingFace tokenizer. - val_file: Path to validation data. - test_file: Path to test data. - batch_size: the batchsize to use for parallel loading. Defaults to 64. - num_workers: The number of workers to use for parallelized loading. - Defaults to None which equals the number of available CPU threads, - or 0 for Darwin platform. - - Returns: - TextClassificationData: The constructed data module. - - Examples:: - - train_df = pd.read_csv("train_data.csv") - tab_data = TabularData.from_df(train_df, target="fraud", - num_cols=["account_value"], - cat_cols=["account_type"]) - - """ - preprocess = preprocess or cls.preprocess_cls( - input, - backbone, - max_length, - target, - filetype, - train_file, - label_to_class_mapping, - ) - - postprocess = postprocess or cls.postprocess_cls() - - return cls.from_load_data_inputs( - train_load_data_input=train_file, - val_load_data_input=val_file, - test_load_data_input=test_file, - predict_load_data_input=predict_file, - batch_size=batch_size, - num_workers=num_workers, - preprocess=preprocess, - postprocess=postprocess, - ) - - @classmethod - def from_file( - cls, - predict_file: str, - input: str, - backbone="bert-base-cased", - filetype="csv", - max_length: int = 128, - label_to_class_mapping: Optional[dict] = None, - batch_size: int = 16, + val_split: Optional[float] = None, + batch_size: int = 4, num_workers: Optional[int] = None, - preprocess: Optional[Preprocess] = None, - postprocess: Optional[Postprocess] = None, - ) -> 'TextClassificationData': - """Creates a TextClassificationData object from files. - - Args: - - predict_file: Path to training data. - input: The field storing the text to be classified. - filetype: .csv or .json - backbone: Tokenizer backbone to use, can use any HuggingFace tokenizer. - batch_size: the batchsize to use for parallel loading. Defaults to 64. 
- num_workers: The number of workers to use for parallelized loading. - Defaults to None which equals the number of available CPU threads, - or 0 for Darwin platform. - """ - return cls.from_files( - None, - input=input, - target=None, - filetype=filetype, - backbone=backbone, - val_file=None, - test_file=None, + ) -> 'DataModule': + return super().from_csv( + input_fields, + target_fields, + train_file=train_file, + val_file=val_file, + test_file=test_file, predict_file=predict_file, - max_length=max_length, - label_to_class_mapping=label_to_class_mapping, + data_fetcher=data_fetcher, + preprocess=preprocess, + val_split=val_split, batch_size=batch_size, num_workers=num_workers, - preprocess=preprocess, - postprocess=postprocess, + data_source_kwargs=dict( + backbone=backbone, + max_length=max_length, + ), ) diff --git a/flash/vision/classification/model.py b/flash/vision/classification/model.py index f11a50eb01..284725175a 100644 --- a/flash/vision/classification/model.py +++ b/flash/vision/classification/model.py @@ -21,6 +21,7 @@ from flash.core.classification import ClassificationTask from flash.core.registry import FlashRegistry +from flash.data.data_source import DefaultDataKeys from flash.data.process import Serializer from flash.vision.backbones import IMAGE_CLASSIFIER_BACKBONES @@ -110,6 +111,22 @@ def __init__( nn.Linear(num_features, num_classes), ) + def training_step(self, batch: Any, batch_idx: int) -> Any: + batch = (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET]) + return super().training_step(batch, batch_idx) + + def validation_step(self, batch: Any, batch_idx: int) -> Any: + batch = (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET]) + return super().validation_step(batch, batch_idx) + + def test_step(self, batch: Any, batch_idx: int) -> Any: + batch = (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET]) + return super().test_step(batch, batch_idx) + + def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: + batch = (batch[DefaultDataKeys.INPUT]) + return super().predict_step(batch, batch_idx, dataloader_idx=dataloader_idx) + def forward(self, x) -> torch.Tensor: x = self.backbone(x) return self.head(x) diff --git a/flash_examples/finetuning/text_classification.py b/flash_examples/finetuning/text_classification.py index efbcac71ea..5f0000cf40 100644 --- a/flash_examples/finetuning/text_classification.py +++ b/flash_examples/finetuning/text_classification.py @@ -19,12 +19,12 @@ download_data("https://pl-flash-data.s3.amazonaws.com/imdb.zip", "data/") # 2. Load the data -datamodule = TextClassificationData.from_files( +datamodule = TextClassificationData.from_csv( train_file="data/imdb/train.csv", val_file="data/imdb/valid.csv", test_file="data/imdb/test.csv", - input="review", - target="sentiment", + input_fields="review", + target_fields="sentiment", batch_size=16, ) diff --git a/flash_examples/predict/text_classification.py b/flash_examples/predict/text_classification.py index e81fd17c52..e9a700a907 100644 --- a/flash_examples/predict/text_classification.py +++ b/flash_examples/predict/text_classification.py @@ -21,7 +21,7 @@ download_data("https://pl-flash-data.s3.amazonaws.com/imdb.zip", "data/") # 2. 
Load the model from a checkpoint -model = TextClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/text_classification_model.pt") +model = TextClassifier.load_from_checkpoint("../finetuning/text_classification_model.pt") model.serializer = Labels() @@ -32,15 +32,17 @@ "I come from Bulgaria where it 's almost impossible to have a tornado.", "Very, very afraid.", "This guy has done a great job with this movie!", -]) +], + data_source="sentences") print(predictions) # 2b. Or generate predictions from a sheet file! -datamodule = TextClassificationData.from_file( +datamodule = TextClassificationData.from_csv( + "review", predict_file="data/imdb/predict.csv", - input="review", - # use the same data pre-processing values we used to predict in 2a - preprocess=model.preprocess, + # input="review", + # # use the same data pre-processing values we used to predict in 2a + # preprocess=model.preprocess, ) predictions = Trainer().predict(model, datamodule=datamodule) print(predictions) From 6e0f69d857190197de225fe7712ee8398cc94e8f Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Wed, 5 May 2021 14:24:14 +0100 Subject: [PATCH 21/78] Small update --- flash_examples/predict/text_classification.py | 21 +++++++++---------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/flash_examples/predict/text_classification.py b/flash_examples/predict/text_classification.py index e9a700a907..21d19854df 100644 --- a/flash_examples/predict/text_classification.py +++ b/flash_examples/predict/text_classification.py @@ -26,23 +26,22 @@ model.serializer = Labels() # 2a. Classify a few sentences! How was the movie? -predictions = model.predict([ - "Turgid dialogue, feeble characterization - Harvey Keitel a judge?.", - "The worst movie in the history of cinema.", - "I come from Bulgaria where it 's almost impossible to have a tornado.", - "Very, very afraid.", - "This guy has done a great job with this movie!", -], - data_source="sentences") +predictions = model.predict( + [ + "Turgid dialogue, feeble characterization - Harvey Keitel a judge?.", + "The worst movie in the history of cinema.", + "I come from Bulgaria where it 's almost impossible to have a tornado.", + "Very, very afraid.", + "This guy has done a great job with this movie!", + ], + data_source="sentences", +) print(predictions) # 2b. Or generate predictions from a sheet file! 
datamodule = TextClassificationData.from_csv( "review", predict_file="data/imdb/predict.csv", - # input="review", - # # use the same data pre-processing values we used to predict in 2a - # preprocess=model.preprocess, ) predictions = Trainer().predict(model, datamodule=datamodule) print(predictions) From a2082bcd734a363d2bb8d0be5dd6a3ee74cb0a20 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Wed, 5 May 2021 17:42:00 +0100 Subject: [PATCH 22/78] Add tabular --- flash/core/model.py | 2 +- flash/data/auto_dataset.py | 2 +- flash/data/data_module.py | 33 +- flash/data/data_source.py | 4 + flash/data/process.py | 17 +- flash/tabular/classification/data/data.py | 409 +++++++++--------- flash/tabular/classification/model.py | 16 +- .../finetuning/tabular_classification.py | 12 +- .../predict/tabular_classification.py | 4 +- 9 files changed, 238 insertions(+), 261 deletions(-) diff --git a/flash/core/model.py b/flash/core/model.py index 3852148395..d4a5864c71 100644 --- a/flash/core/model.py +++ b/flash/core/model.py @@ -317,7 +317,7 @@ def build_data_pipeline( data_source = data_source or old_data_source if isinstance(data_source, str): - data_source = preprocess.data_source_of_name(data_source)() + data_source = preprocess.data_source_of_name(data_source) data_pipeline = DataPipeline(data_source, preprocess, postprocess, serializer) self._data_pipeline_state = data_pipeline.initialize(self._data_pipeline_state) diff --git a/flash/data/auto_dataset.py b/flash/data/auto_dataset.py index 1c468dfb37..4ede9bb5b7 100644 --- a/flash/data/auto_dataset.py +++ b/flash/data/auto_dataset.py @@ -73,8 +73,8 @@ def running_stage(self, running_stage: RunningStage) -> None: ) def _call_load_sample(self, sample: Any) -> Any: - sample = dict(**sample) if self.load_sample: + sample = dict(**sample) with self._load_sample_context: parameters = signature(self.load_sample).parameters if len(parameters) > 1 and self.DATASET_KEY in parameters: diff --git a/flash/data/data_module.py b/flash/data/data_module.py index 13a43b6301..dac3d9d144 100644 --- a/flash/data/data_module.py +++ b/flash/data/data_module.py @@ -350,8 +350,7 @@ def from_data_source( val_split: Optional[float] = None, batch_size: int = 4, num_workers: Optional[int] = None, - data_source_kwargs: Dict[str, Any] = {}, - preprocess_kwargs: Dict[str, Any] = {}, + **preprocess_kwargs: Any, ) -> 'DataModule': preprocess = preprocess or cls.preprocess_cls( train_transform, @@ -361,7 +360,7 @@ def from_data_source( **preprocess_kwargs, ) - data_source = preprocess.data_source_of_name(data_source)(**data_source_kwargs) + data_source = preprocess.data_source_of_name(data_source) train_dataset, val_dataset, test_dataset, predict_dataset = data_source.to_datasets( train_data, @@ -399,8 +398,7 @@ def from_folders( val_split: Optional[float] = None, batch_size: int = 4, num_workers: Optional[int] = None, - data_source_kwargs: Dict[str, Any] = {}, - preprocess_kwargs: Dict[str, Any] = {}, + **preprocess_kwargs: Any, ) -> 'DataModule': return cls.from_data_source( DefaultDataSources.FOLDERS, @@ -417,8 +415,7 @@ def from_folders( val_split=val_split, batch_size=batch_size, num_workers=num_workers, - data_source_kwargs=data_source_kwargs, - preprocess_kwargs=preprocess_kwargs, + **preprocess_kwargs, ) @classmethod @@ -440,8 +437,7 @@ def from_files( val_split: Optional[float] = None, batch_size: int = 4, num_workers: Optional[int] = None, - data_source_kwargs: Dict[str, Any] = {}, - preprocess_kwargs: Dict[str, Any] = {}, + **preprocess_kwargs: Any, ) -> 'DataModule': 
return cls.from_data_source( DefaultDataSources.FILES, @@ -458,8 +454,7 @@ def from_files( val_split=val_split, batch_size=batch_size, num_workers=num_workers, - data_source_kwargs=data_source_kwargs, - preprocess_kwargs=preprocess_kwargs, + **preprocess_kwargs, ) @classmethod @@ -481,8 +476,7 @@ def from_tensors( val_split: Optional[float] = None, batch_size: int = 4, num_workers: Optional[int] = None, - data_source_kwargs: Dict[str, Any] = {}, - preprocess_kwargs: Dict[str, Any] = {}, + **preprocess_kwargs: Any, ) -> 'DataModule': return cls.from_data_source( DefaultDataSources.TENSOR, @@ -499,8 +493,7 @@ def from_tensors( val_split=val_split, batch_size=batch_size, num_workers=num_workers, - data_source_kwargs=data_source_kwargs, - preprocess_kwargs=preprocess_kwargs, + **preprocess_kwargs, ) @classmethod @@ -522,8 +515,7 @@ def from_numpy( val_split: Optional[float] = None, batch_size: int = 4, num_workers: Optional[int] = None, - data_source_kwargs: Dict[str, Any] = {}, - preprocess_kwargs: Dict[str, Any] = {}, + **preprocess_kwargs: Any, ) -> 'DataModule': return cls.from_data_source( DefaultDataSources.NUMPY, @@ -540,8 +532,7 @@ def from_numpy( val_split=val_split, batch_size=batch_size, num_workers=num_workers, - data_source_kwargs=data_source_kwargs, - preprocess_kwargs=preprocess_kwargs, + **preprocess_kwargs, ) @classmethod @@ -558,8 +549,7 @@ def from_csv( val_split: Optional[float] = None, batch_size: int = 4, num_workers: Optional[int] = None, - data_source_kwargs: Dict[str, Any] = {}, - preprocess_kwargs: Dict[str, Any] = {}, + **preprocess_kwargs: Any, ) -> 'DataModule': return cls.from_data_source( DefaultDataSources.CSV, @@ -572,4 +562,5 @@ def from_csv( val_split=val_split, batch_size=batch_size, num_workers=num_workers, + **preprocess_kwargs, ) diff --git a/flash/data/data_source.py b/flash/data/data_source.py index 8e357e3a96..fac7f77856 100644 --- a/flash/data/data_source.py +++ b/flash/data/data_source.py @@ -135,6 +135,10 @@ class DefaultDataSources(LightningEnum): TENSOR = "tensor" CSV = "csv" + # TODO: Create a FlashEnum class??? 
+ def __hash__(self) -> int: + return hash(self.value) + class DefaultDataKeys(LightningEnum): diff --git a/flash/data/process.py b/flash/data/process.py index f6ec29c55a..ff8fa18da3 100644 --- a/flash/data/process.py +++ b/flash/data/process.py @@ -304,14 +304,13 @@ def load_data(cls, path_to_data: str) -> Iterable: """ - data_sources: Optional[Dict[str, Type['DataSource']]] - def __init__( self, train_transform: Optional[Dict[str, Callable]] = None, val_transform: Optional[Dict[str, Callable]] = None, test_transform: Optional[Dict[str, Callable]] = None, predict_transform: Optional[Dict[str, Callable]] = None, + data_sources: Optional[Dict[str, 'DataSource']] = None, ): super().__init__() @@ -337,6 +336,7 @@ def __init__( self.test_transform = convert_to_modules(self._test_transform) self.predict_transform = convert_to_modules(self._predict_transform) + self._data_sources = data_sources self._callbacks: List[FlashCallback] = [] def _save_to_state_dict(self, destination, prefix, keep_vars): @@ -514,17 +514,8 @@ def per_batch_transform_on_device(self, batch: Any) -> Any: """ return self.current_transform(batch) - @classmethod - def data_source_of_type(cls, data_source_type: Type[DATA_SOURCE_TYPE]) -> Optional[Type[DATA_SOURCE_TYPE]]: - data_sources = cls.data_sources - for data_source in data_sources.values(): - if issubclass(data_source, data_source_type): - return data_source - return None - - @classmethod - def data_source_of_name(cls, data_source_name: str) -> Optional[Type[DATA_SOURCE_TYPE]]: - data_sources = cls.data_sources + def data_source_of_name(self, data_source_name: str) -> Optional[DATA_SOURCE_TYPE]: + data_sources = self._data_sources if data_source_name in data_sources: return data_sources[data_source_name] return None diff --git a/flash/tabular/classification/data/data.py b/flash/tabular/classification/data/data.py index 87e7751bff..f85500b422 100644 --- a/flash/tabular/classification/data/data.py +++ b/flash/tabular/classification/data/data.py @@ -11,18 +11,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Any, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np import pandas as pd from pandas.core.frame import DataFrame from pytorch_lightning.utilities.exceptions import MisconfigurationException -from sklearn.model_selection import train_test_split -from torch.utils.data import Dataset from flash.core.classification import LabelsState -from flash.data.auto_dataset import AutoDataset from flash.data.data_module import DataModule +from flash.data.data_source import DataSource, DefaultDataKeys, DefaultDataSources from flash.data.process import Preprocess from flash.tabular.classification.data.dataset import ( _compute_normalization, @@ -33,23 +31,21 @@ ) -class TabularPreprocess(Preprocess): +class TabularDataFrameDataSource(DataSource[DataFrame]): def __init__( self, - cat_cols: List[str], - num_cols: List[str], - target_col: str, - mean: DataFrame, - std: DataFrame, - codes: Dict[str, Any], - target_codes: Optional[Dict[str, Any]], - classes: List[str], - num_classes: int, - is_regression: bool, + cat_cols: Optional[List[str]] = None, + num_cols: Optional[List[str]] = None, + target_col: Optional[str] = None, + mean: Optional[DataFrame] = None, + std: Optional[DataFrame] = None, + codes: Optional[Dict[str, Any]] = None, + target_codes: Optional[Dict[str, Any]] = None, + classes: Optional[List[str]] = None, + is_regression: bool = True, ): super().__init__() - self.set_state(LabelsState(classes)) self.cat_cols = cat_cols self.num_cols = num_cols @@ -58,28 +54,16 @@ def __init__( self.std = std self.codes = codes self.target_codes = target_codes - self.num_classes = num_classes self.is_regression = is_regression - def get_state_dict(self, strict: bool = False) -> Dict[str, Any]: - return { - "cat_cols": self.cat_cols, - "num_cols": self.num_cols, - "target_col": self.target_col, - "mean": self.mean, - "std": self.std, - "codes": self.codes, - "target_codes": self.target_codes, - "classes": self.num_classes, - "num_classes": self.num_classes, - "is_regression": self.is_regression, - } - - @classmethod - def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool = True) -> 'Preprocess': - return cls(**state_dict) + self.set_state(LabelsState(classes)) + self.num_classes = len(classes) - def common_load_data(self, df: DataFrame, dataset: AutoDataset): + def common_load_data( + self, + df: DataFrame, + dataset: Optional[Any] = None, + ): # impute_data # compute train dataset stats dfs = _pre_transform([df], self.num_cols, self.cat_cols, self.codes, self.mean, self.std, self.target_col, @@ -87,181 +71,124 @@ def common_load_data(self, df: DataFrame, dataset: AutoDataset): df = dfs[0] - dataset.num_samples = len(df) + if dataset is not None: + dataset.num_samples = len(df) + cat_vars = _to_cat_vars_numpy(df, self.cat_cols) num_vars = _to_num_vars_numpy(df, self.num_cols) - cat_vars = np.stack(cat_vars, 1) if len(cat_vars) else np.zeros((len(self), 0)) - num_vars = np.stack(num_vars, 1) if len(num_vars) else np.zeros((len(self), 0)) + cat_vars = np.stack(cat_vars, 1) # if len(cat_vars) else np.zeros((len(self), 0)) + num_vars = np.stack(num_vars, 1) # if len(num_vars) else np.zeros((len(self), 0)) return df, cat_vars, num_vars - def load_data(self, df: DataFrame, dataset: AutoDataset): - df, cat_vars, num_vars = self.common_load_data(df, dataset) + def load_data(self, data: DataFrame, dataset: Optional[Any] = None): + df, cat_vars, num_vars = self.common_load_data(data, dataset=dataset) target = 
df[self.target_col].to_numpy().astype(np.float32 if self.is_regression else np.int64) - return [((c, n), t) for c, n, t in zip(cat_vars, num_vars, target)] + return [{ + DefaultDataKeys.INPUT: (c, n), + DefaultDataKeys.TARGET: t + } for c, n, t in zip(cat_vars, num_vars, target)] - def predict_load_data(self, sample: Union[str, DataFrame], dataset: AutoDataset): - df = pd.read_csv(sample) if isinstance(sample, str) else sample - _, cat_vars, num_vars = self.common_load_data(df, dataset) - return list(zip(cat_vars, num_vars)) + def predict_load_data(self, data: DataFrame, dataset: Optional[Any] = None): + _, cat_vars, num_vars = self.common_load_data(data, dataset=dataset) + return [{DefaultDataKeys.INPUT: (c, n)} for c, n in zip(cat_vars, num_vars)] - @classmethod - def from_data( - cls, - train_df: DataFrame, - val_df: Optional[DataFrame], - test_df: Optional[DataFrame], - predict_df: Optional[DataFrame], - target_col: str, - num_cols: List[str], - cat_cols: List[str], - is_regression: bool, - ) -> 'TabularPreprocess': - if train_df is None: - raise MisconfigurationException("train_df is required to instantiate the TabularPreprocess") +class TabularCSVDataSource(TabularDataFrameDataSource): - dfs = [train_df] + def load_data(self, data: str, dataset: Optional[Any] = None): + return super().load_data(pd.read_csv(data), dataset=dataset) - if val_df is not None: - dfs += [val_df] + def predict_load_data(self, data: str, dataset: Optional[Any] = None): + return super().predict_load_data(pd.read_csv(data), dataset=dataset) - if test_df is not None: - dfs += [test_df] - if predict_df is not None: - dfs += [predict_df] +class TabularPreprocess(Preprocess): - mean, std = _compute_normalization(dfs[0], num_cols) - classes = list(dfs[0][target_col].unique()) - num_classes = len(classes) - if dfs[0][target_col].dtype == object: - # if the target_col is a category, not an int - target_codes = _generate_codes(dfs, [target_col]) - else: - target_codes = None - codes = _generate_codes(dfs, cat_cols) + def __init__( + self, + train_transform: Optional[Dict[str, Callable]] = None, + val_transform: Optional[Dict[str, Callable]] = None, + test_transform: Optional[Dict[str, Callable]] = None, + predict_transform: Optional[Dict[str, Callable]] = None, + cat_cols: Optional[List[str]] = None, + num_cols: Optional[List[str]] = None, + target_col: Optional[str] = None, + mean: Optional[DataFrame] = None, + std: Optional[DataFrame] = None, + codes: Optional[Dict[str, Any]] = None, + target_codes: Optional[Dict[str, Any]] = None, + classes: Optional[List[str]] = None, + is_regression: bool = True, + ): + self.cat_cols = cat_cols + self.num_cols = num_cols + self.target_col = target_col + self.mean = mean + self.std = std + self.codes = codes + self.target_codes = target_codes + self.classes = classes + self.is_regression = is_regression - return cls( - cat_cols, - num_cols, - target_col, - mean, - std, - codes, - target_codes, - classes, - num_classes, - is_regression, + super().__init__( + train_transform=train_transform, + val_transform=val_transform, + test_transform=test_transform, + predict_transform=predict_transform, + data_sources={ + DefaultDataSources.CSV: TabularCSVDataSource( + cat_cols, num_cols, target_col, mean, std, codes, target_codes, classes, is_regression + ), + "df": TabularDataFrameDataSource( + cat_cols, num_cols, target_col, mean, std, codes, target_codes, classes, is_regression + ), + } ) + def get_state_dict(self, strict: bool = False) -> Dict[str, Any]: + return { + "cat_cols": 
self.cat_cols, + "num_cols": self.num_cols, + "target_col": self.target_col, + "mean": self.mean, + "std": self.std, + "codes": self.codes, + "target_codes": self.target_codes, + "classes": self.classes, + "is_regression": self.is_regression, + } + + @classmethod + def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool = True) -> 'Preprocess': + return cls(**state_dict) + class TabularData(DataModule): """Data module for tabular tasks""" preprocess_cls = TabularPreprocess - def __init__( - self, - train_dataset: Optional[Dataset] = None, - val_dataset: Optional[Dataset] = None, - test_dataset: Optional[Dataset] = None, - predict_dataset: Optional[Dataset] = None, - batch_size: int = 1, - num_workers: Optional[int] = 0, - ) -> None: - super().__init__( - train_dataset, - val_dataset, - test_dataset, - predict_dataset, - batch_size=batch_size, - num_workers=num_workers, - ) - - self._preprocess: Optional[Preprocess] = None - @property def codes(self) -> Dict[str, str]: - return self._preprocess.codes + return self._data_source.codes @property def num_classes(self) -> int: - return self._preprocess.num_classes + return self._data_source.num_classes @property def cat_cols(self) -> Optional[List[str]]: - return self._preprocess.cat_cols + return self._data_source.cat_cols @property def num_cols(self) -> Optional[List[str]]: - return self._preprocess.num_cols + return self._data_source.num_cols @property def num_features(self) -> int: return len(self.cat_cols) + len(self.num_cols) - @classmethod - def from_csv( - cls, - target_col: str, - train_csv: Optional[str] = None, - categorical_cols: Optional[List] = None, - numerical_cols: Optional[List] = None, - val_csv: Optional[str] = None, - test_csv: Optional[str] = None, - predict_csv: Optional[str] = None, - batch_size: int = 8, - num_workers: Optional[int] = None, - val_size: Optional[float] = None, - test_size: Optional[float] = None, - preprocess: Optional[Preprocess] = None, - **pandas_kwargs, - ): - """Creates a TextClassificationData object from pandas DataFrames. - - Args: - train_csv: Train data csv file. - target_col: The column containing the class id. - categorical_cols: The list of categorical columns. - numerical_cols: The list of numerical columns. - val_csv: Validation data csv file. - test_csv: Test data csv file. - batch_size: The batchsize to use for parallel loading. Defaults to 64. - num_workers: The number of workers to use for parallelized loading. - Defaults to None which equals the number of available CPU threads, - or 0 for Darwin platform. - val_size: Float between 0 and 1 to create a validation dataset from train dataset. - test_size: Float between 0 and 1 to create a test dataset from train validation. - preprocess: Preprocess to be used within this DataModule DataPipeline. - - Returns: - TabularData: The constructed data module. 
- - Examples:: - - text_data = TabularData.from_files("train.csv", label_field="class", text_field="sentence") - """ - train_df = pd.read_csv(train_csv, **pandas_kwargs) - val_df = pd.read_csv(val_csv, **pandas_kwargs) if val_csv else None - test_df = pd.read_csv(test_csv, **pandas_kwargs) if test_csv else None - predict_df = pd.read_csv(predict_csv, **pandas_kwargs) if predict_csv else None - - return cls.from_df( - train_df, - target_col, - categorical_cols, - numerical_cols, - val_df, - test_df, - predict_df, - batch_size, - num_workers, - val_size, - test_size, - preprocess=preprocess, - ) - @property def emb_sizes(self) -> list: """Recommended embedding sizes.""" @@ -273,25 +200,6 @@ def emb_sizes(self) -> list: emb_dims = [max(int(n**0.25), 16) for n in num_classes] return list(zip(num_classes, emb_dims)) - @staticmethod - def _split_dataframe( - train_df: DataFrame, - val_df: Optional[DataFrame] = None, - test_df: Optional[DataFrame] = None, - val_size: float = None, - test_size: float = None, - ): - if val_df is None and isinstance(val_size, float) and isinstance(test_size, float): - assert 0 < val_size < 1 - assert 0 < test_size < 1 - train_df, val_df = train_test_split(train_df, test_size=(val_size + test_size)) - - if test_df is None and isinstance(test_size, float): - assert 0 < test_size < 1 - val_df, test_df = train_test_split(val_df, test_size=test_size) - - return train_df, val_df, test_df - @staticmethod def _sanetize_cols(cat_cols: Optional[List], num_cols: Optional[List]): if cat_cols is None and num_cols is None: @@ -300,21 +208,58 @@ def _sanetize_cols(cat_cols: Optional[List], num_cols: Optional[List]): return cat_cols or [], num_cols or [] @classmethod - def from_df( + def compute_state( cls, train_df: DataFrame, + val_df: Optional[DataFrame], + test_df: Optional[DataFrame], + predict_df: Optional[DataFrame], + target_col: str, + num_cols: List[str], + cat_cols: List[str], + ) -> Tuple[float, float, List[str], Dict[str, Any], Dict[str, Any]]: + + if train_df is None: + raise MisconfigurationException("train_df is required to instantiate the TabularDataFrameDataSource") + + dfs = [train_df] + + if val_df is not None: + dfs += [val_df] + + if test_df is not None: + dfs += [test_df] + + if predict_df is not None: + dfs += [predict_df] + + mean, std = _compute_normalization(dfs[0], num_cols) + classes = list(dfs[0][target_col].unique()) + + if dfs[0][target_col].dtype == object: + # if the target_col is a category, not an int + target_codes = _generate_codes(dfs, [target_col]) + else: + target_codes = None + codes = _generate_codes(dfs, cat_cols) + + return mean, std, classes, codes, target_codes + + @classmethod + def from_df( + cls, + categorical_cols: List, + numerical_cols: List, target_col: str, - categorical_cols: Optional[List] = None, - numerical_cols: Optional[List] = None, + train_df: DataFrame, val_df: Optional[DataFrame] = None, test_df: Optional[DataFrame] = None, predict_df: Optional[DataFrame] = None, - batch_size: int = 8, - num_workers: Optional[int] = None, - val_size: float = None, - test_size: float = None, is_regression: bool = False, preprocess: Optional[Preprocess] = None, + val_split: float = None, + batch_size: int = 8, + num_workers: Optional[int] = None, ): """Creates a TabularData object from pandas DataFrames. @@ -329,8 +274,7 @@ def from_df( num_workers: The number of workers to use for parallelized loading. Defaults to None which equals the number of available CPU threads, or 0 for Darwin platform. 
- val_size: Float between 0 and 1 to create a validation dataset from train dataset. - test_size: Float between 0 and 1 to create a test dataset from train validation. + val_split: Float between 0 and 1 to create a validation dataset from train dataset. preprocess: Preprocess to be used within this DataModule DataPipeline. Returns: @@ -342,25 +286,58 @@ def from_df( """ categorical_cols, numerical_cols = cls._sanetize_cols(categorical_cols, numerical_cols) - train_df, val_df, test_df = cls._split_dataframe(train_df, val_df, test_df, val_size, test_size) - - preprocess = preprocess or cls.preprocess_cls.from_data( - train_df, - val_df, - test_df, - predict_df, - target_col, - numerical_cols, - categorical_cols, - is_regression, + mean, std, classes, codes, target_codes = cls.compute_state( + train_df, val_df, test_df, predict_df, target_col, numerical_cols, categorical_cols + ) + + return cls.from_data_source( + data_source="df", + train_data=train_df, + val_data=val_df, + test_data=test_df, + predict_data=predict_df, + preprocess=preprocess, + val_split=val_split, + batch_size=batch_size, + num_workers=num_workers, + cat_cols=categorical_cols, + num_cols=numerical_cols, + target_col=target_col, + mean=mean, + std=std, + codes=codes, + target_codes=target_codes, + classes=classes, + is_regression=is_regression, ) - return cls.from_load_data_inputs( - train_load_data_input=train_df, - val_load_data_input=val_df, - test_load_data_input=test_df, - predict_load_data_input=predict_df, + @classmethod + def from_csv( + cls, + categorical_fields: Union[str, List[str]], + numerical_fields: Union[str, List[str]], + target_field: Optional[str] = None, + train_file: Optional[str] = None, + val_file: Optional[str] = None, + test_file: Optional[str] = None, + predict_file: Optional[str] = None, + is_regression: bool = False, + preprocess: Optional[Preprocess] = None, + val_split: Optional[float] = None, + batch_size: int = 4, + num_workers: Optional[int] = None, + ) -> 'DataModule': + return cls.from_df( + categorical_fields, + numerical_fields, + target_field, + train_df=pd.read_csv(train_file) if train_file is not None else None, + val_df=pd.read_csv(val_file) if val_file is not None else None, + test_df=pd.read_csv(test_file) if test_file is not None else None, + predict_df=pd.read_csv(predict_file) if predict_file is not None else None, + is_regression=is_regression, + preprocess=preprocess, + val_split=val_split, batch_size=batch_size, num_workers=num_workers, - preprocess=preprocess ) diff --git a/flash/tabular/classification/model.py b/flash/tabular/classification/model.py index 7e399aaaa2..6fde330784 100644 --- a/flash/tabular/classification/model.py +++ b/flash/tabular/classification/model.py @@ -18,6 +18,7 @@ from torchmetrics import Metric from flash.core.classification import ClassificationTask +from flash.data.data_source import DefaultDataKeys from flash.data.process import Serializer from flash.utils.imports import _TABNET_AVAILABLE @@ -80,9 +81,22 @@ def __init__( def forward(self, x_in) -> torch.Tensor: # TabNet takes single input, x_in is composed of (categorical, numerical) x = torch.cat([x for x in x_in if x.numel()], dim=1) - return F.softmax(self.model(x)[0], -1) + return self.model(x)[0] + + def training_step(self, batch: Any, batch_idx: int) -> Any: + batch = (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET]) + return super().training_step(batch, batch_idx) + + def validation_step(self, batch: Any, batch_idx: int) -> Any: + batch = (batch[DefaultDataKeys.INPUT], 
batch[DefaultDataKeys.TARGET]) + return super().validation_step(batch, batch_idx) + + def test_step(self, batch: Any, batch_idx: int) -> Any: + batch = (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET]) + return super().test_step(batch, batch_idx) def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: + batch = (batch[DefaultDataKeys.INPUT]) return self(batch) @classmethod diff --git a/flash_examples/finetuning/tabular_classification.py b/flash_examples/finetuning/tabular_classification.py index 9d5b8ad256..ad8a949455 100644 --- a/flash_examples/finetuning/tabular_classification.py +++ b/flash_examples/finetuning/tabular_classification.py @@ -22,12 +22,12 @@ # 2. Load the data datamodule = TabularData.from_csv( - target_col="Survived", - train_csv="./data/titanic/titanic.csv", - test_csv="./data/titanic/test.csv", - categorical_cols=["Sex", "Age", "SibSp", "Parch", "Ticket", "Cabin", "Embarked"], - numerical_cols=["Fare"], - val_size=0.25, + ["Sex", "Age", "SibSp", "Parch", "Ticket", "Cabin", "Embarked"], + ["Fare"], + target_field="Survived", + train_file="./data/titanic/titanic.csv", + test_file="./data/titanic/test.csv", + val_split=0.25, ) # 3. Build the model diff --git a/flash_examples/predict/tabular_classification.py b/flash_examples/predict/tabular_classification.py index a874d1f99f..dcee9c859d 100644 --- a/flash_examples/predict/tabular_classification.py +++ b/flash_examples/predict/tabular_classification.py @@ -19,10 +19,10 @@ download_data("https://pl-flash-data.s3.amazonaws.com/titanic.zip", "data/") # 2. Load the model from a checkpoint -model = TabularClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/tabular_classification_model.pt") +model = TabularClassifier.load_from_checkpoint("../finetuning/tabular_classification_model.pt") model.serializer = Labels(['Did not survive', 'Survived']) # 3. Generate predictions from a sheet file! Who would survive? 
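# [Illustrative note, not part of the original patch: ``TabularPreprocess`` above
#  registers two data sources — ``DefaultDataSources.CSV`` ("csv") and "df" — so the
#  ``data_source="csv"`` argument used below routes the file path through
#  ``TabularCSVDataSource``. Each prediction sample that source loads is a mapping of
#  the form sketched here, keyed by ``DefaultDataKeys``:]
#
#     {DefaultDataKeys.INPUT: (categorical_vars, numerical_vars)}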
-predictions = model.predict("data/titanic/titanic.csv") +predictions = model.predict("data/titanic/titanic.csv", data_source="csv") print(predictions) From fd076445567098e016e0985d19047fc972bca4fc Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Wed, 5 May 2021 17:59:43 +0100 Subject: [PATCH 23/78] Fixes --- flash/text/classification/data.py | 37 +++++++++++++++++++------- flash/vision/classification/data.py | 13 +++++---- flash/vision/detection/data.py | 41 +++++++++++++++++------------ 3 files changed, 58 insertions(+), 33 deletions(-) diff --git a/flash/text/classification/data.py b/flash/text/classification/data.py index 94d558b858..196c0a48a8 100644 --- a/flash/text/classification/data.py +++ b/flash/text/classification/data.py @@ -140,13 +140,34 @@ def load_data( class TextClassificationPreprocess(Preprocess): - data_sources = { - DefaultDataSources.CSV: TextCSVDataSource, - "sentences": TextSentencesDataSource, - } + def __init__( + self, + train_transform: Optional[Dict[str, Callable]] = None, + val_transform: Optional[Dict[str, Callable]] = None, + test_transform: Optional[Dict[str, Callable]] = None, + predict_transform: Optional[Dict[str, Callable]] = None, + backbone: str = "prajjwal1/bert-tiny", + max_length: int = 128, + ): + self.backbone = backbone + self.max_length = max_length + + super().__init__( + train_transform=train_transform, + val_transform=val_transform, + test_transform=test_transform, + predict_transform=predict_transform, + data_sources={ + DefaultDataSources.CSV: TextCSVDataSource(backbone=backbone, max_length=max_length), + "sentences": TextSentencesDataSource(backbone=backbone, max_length=max_length), + } + ) def get_state_dict(self) -> Dict[str, Any]: - return {} + return { + "backbone": self.backbone, + "max_length": self.max_length, + } @classmethod def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool): @@ -209,8 +230,6 @@ def from_csv( val_split=val_split, batch_size=batch_size, num_workers=num_workers, - data_source_kwargs=dict( - backbone=backbone, - max_length=max_length, - ), + backbone=backbone, + max_length=max_length, ) diff --git a/flash/vision/classification/data.py b/flash/vision/classification/data.py index 6781fb54ed..dbfcbe97a0 100644 --- a/flash/vision/classification/data.py +++ b/flash/vision/classification/data.py @@ -37,13 +37,6 @@ class ImageClassificationPreprocess(Preprocess): - data_sources = { - DefaultDataSources.FOLDERS: ImageFoldersDataSource, - DefaultDataSources.FILES: ImageFilesDataSource, - DefaultDataSources.NUMPY: ImageNumpyDataSource, - DefaultDataSources.TENSOR: ImageTensorDataSource, - } - def __init__( self, train_transform: Optional[Union[Dict[str, Callable]]] = None, @@ -59,6 +52,12 @@ def __init__( val_transform=val_transform, test_transform=test_transform, predict_transform=predict_transform, + data_sources={ + DefaultDataSources.FOLDERS: ImageFoldersDataSource(), + DefaultDataSources.FILES: ImageFilesDataSource(), + DefaultDataSources.NUMPY: ImageNumpyDataSource(), + DefaultDataSources.TENSOR: ImageTensorDataSource(), + } ) def get_state_dict(self) -> Dict[str, Any]: diff --git a/flash/vision/detection/data.py b/flash/vision/detection/data.py index 4af7ed3de9..6c1b0e6a97 100644 --- a/flash/vision/detection/data.py +++ b/flash/vision/detection/data.py @@ -12,9 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import os -from typing import Any, Callable, Dict, Optional, Sequence, Tuple +from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union -from pytorch_lightning.trainer.states import RunningStage from torch.nn import Module from torchvision.datasets.folder import default_loader @@ -91,9 +90,22 @@ def load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]: class ObjectDetectionPreprocess(Preprocess): - data_sources = { - "coco": COCODataSource, - } + def __init__( + self, + train_transform: Optional[Union[Dict[str, Callable]]] = None, + val_transform: Optional[Union[Dict[str, Callable]]] = None, + test_transform: Optional[Union[Dict[str, Callable]]] = None, + predict_transform: Optional[Union[Dict[str, Callable]]] = None, + ): + super().__init__( + train_transform=train_transform, + val_transform=val_transform, + test_transform=test_transform, + predict_transform=predict_transform, + data_sources={ + "coco": COCODataSource(), + } + ) def collate(self, samples: Any) -> Any: return {key: [sample[key] for sample in samples] for key in samples[0]} @@ -136,23 +148,18 @@ def from_coco( batch_size: int = 4, num_workers: Optional[int] = None, preprocess: Preprocess = None, - **kwargs + val_split: Optional[float] = None, ): - preprocess = preprocess or cls.preprocess_cls( - train_transform, - val_transform, - test_transform, - ) - - data_source = preprocess.data_source_of_type(COCODataSource)() - return cls.from_data_source( - data_source=data_source, + data_source="coco", train_data=(train_folder, train_ann_file) if train_folder else None, val_data=(val_folder, val_ann_file) if val_folder else None, test_data=(test_folder, test_ann_file) if test_folder else None, + train_transform=train_transform, + val_transform=val_transform, + test_transform=test_transform, + preprocess=preprocess, + val_split=val_split, batch_size=batch_size, num_workers=num_workers, - preprocess=preprocess, - **kwargs ) From 19e966d8cece25763810cda8324b11f8092a6d80 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Wed, 5 May 2021 18:07:02 +0100 Subject: [PATCH 24/78] Fixes --- .../image_classification_multi_label.py | 18 +++++++++--------- .../image_classification_multi_label.py | 8 +++----- 2 files changed, 12 insertions(+), 14 deletions(-) diff --git a/flash_examples/finetuning/image_classification_multi_label.py b/flash_examples/finetuning/image_classification_multi_label.py index ca2360c519..5ae824d74f 100644 --- a/flash_examples/finetuning/image_classification_multi_label.py +++ b/flash_examples/finetuning/image_classification_multi_label.py @@ -39,16 +39,16 @@ def load_data(data: str, root: str = 'data/movie_posters') -> Tuple[List[str], L [[int(row[genre]) for genre in genres] for _, row in metadata.iterrows()]) -train_filepaths, train_labels = load_data('train') -test_filepaths, test_labels = load_data('test') - -datamodule = ImageClassificationData.from_filepaths( - train_filepaths=train_filepaths, - train_labels=train_labels, - test_filepaths=test_filepaths, - test_labels=test_labels, - preprocess=ImageClassificationPreprocess(image_size=(128, 128)), +train_files, train_targets = load_data('train') +test_files, test_targets = load_data('test') + +datamodule = ImageClassificationData.from_files( + train_files=train_files, + train_targets=train_targets, + test_files=test_files, + test_targets=test_targets, val_split=0.1, # Use 10 % of the train dataset to generate validation one. + image_size=(128, 128), ) # 3. 
Build the model diff --git a/flash_examples/predict/image_classification_multi_label.py b/flash_examples/predict/image_classification_multi_label.py index c20f78172b..77b5f35978 100644 --- a/flash_examples/predict/image_classification_multi_label.py +++ b/flash_examples/predict/image_classification_multi_label.py @@ -33,16 +33,14 @@ class CustomViz(BaseVisualization): def show_per_batch_transform(self, batch: Any, _) -> None: - images = batch[0] + images = batch[0]["input"] image = make_grid(images, nrow=2) image = T.to_pil_image(image, 'RGB') image.show() # 3. Load the model from a checkpoint -model = ImageClassifier.load_from_checkpoint( - "https://flash-weights.s3.amazonaws.com/image_classification_multi_label_model.pt", -) +model = ImageClassifier.load_from_checkpoint("../finetuning/image_classification_multi_label_model.pt", ) # 4a. Predict the genres of a few movie posters! predictions = model.predict([ @@ -56,7 +54,7 @@ def show_per_batch_transform(self, batch: Any, _) -> None: datamodule = ImageClassificationData.from_folders( predict_folder="data/movie_posters/predict/", data_fetcher=CustomViz(), - preprocess=model.preprocess, + image_size=(128, 128), ) predictions = Trainer().predict(model, datamodule=datamodule) From 3b7ab0e506dc3a7bd1b8b2fead4c809773dda671 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Thu, 6 May 2021 10:00:24 +0100 Subject: [PATCH 25/78] Add summarization example --- flash/text/classification/data.py | 30 +-- flash/text/seq2seq/core/data.py | 299 ++++++++++----------- flash/text/seq2seq/summarization/data.py | 134 +-------- flash/text/seq2seq/summarization/metric.py | 6 +- flash/text/seq2seq/summarization/model.py | 4 +- flash_examples/finetuning/summarization.py | 6 +- flash_examples/predict/summarization.py | 53 ++-- 7 files changed, 192 insertions(+), 340 deletions(-) diff --git a/flash/text/classification/data.py b/flash/text/classification/data.py index 196c0a48a8..c3b1a4e6e7 100644 --- a/flash/text/classification/data.py +++ b/flash/text/classification/data.py @@ -29,18 +29,16 @@ class TextDataSource(DataSource): - def __init__(self, backbone: str = "prajjwal1/bert-tiny", max_length: int = 128): + def __init__(self, tokenizer, max_length: int = 128): super().__init__() - self.backbone = backbone + self.tokenizer = tokenizer self.max_length = max_length - self.tokenizer = AutoTokenizer.from_pretrained(backbone, use_fast=True) - def _tokenize_fn( self, ex: Union[Dict[str, str], str], - input: str = None, + input: Optional[str] = None, ) -> Callable: """This function is used to tokenize sentences using the provided tokenizer.""" if isinstance(ex, dict): @@ -54,8 +52,8 @@ def _transform_label(self, label_to_class_mapping: Dict[str, int], target: str, class TextFileDataSource(TextDataSource): - def __init__(self, filetype: str, backbone: str = "prajjwal1/bert-tiny", max_length: int = 128): - super().__init__(backbone=backbone, max_length=max_length) + def __init__(self, filetype: str, tokenizer, max_length: int = 128): + super().__init__(tokenizer, max_length=max_length) self.filetype = filetype @@ -112,20 +110,20 @@ def predict_load_data(self, data: Any, dataset: AutoDataset): class TextCSVDataSource(TextFileDataSource): - def __init__(self, backbone: str = "prajjwal1/bert-tiny", max_length: int = 128): - super().__init__("csv", backbone=backbone, max_length=max_length) + def __init__(self, tokenizer, max_length: int = 128): + super().__init__("csv", tokenizer, max_length=max_length) class TextJSONDataSource(TextFileDataSource): - def __init__(self, 
backbone: str = "prajjwal1/bert-tiny", max_length: int = 128): - super().__init__("json", backbone=backbone, max_length=max_length) + def __init__(self, tokenizer, max_length: int = 128): + super().__init__("json", tokenizer, max_length=max_length) class TextSentencesDataSource(TextDataSource): - def __init__(self, backbone: str = "prajjwal1/bert-tiny", max_length: int = 128): - super().__init__(backbone=backbone, max_length=max_length) + def __init__(self, tokenizer, max_length: int = 128): + super().__init__(tokenizer, max_length=max_length) def load_data( self, @@ -152,14 +150,16 @@ def __init__( self.backbone = backbone self.max_length = max_length + self.tokenizer = AutoTokenizer.from_pretrained(backbone, use_fast=True) + super().__init__( train_transform=train_transform, val_transform=val_transform, test_transform=test_transform, predict_transform=predict_transform, data_sources={ - DefaultDataSources.CSV: TextCSVDataSource(backbone=backbone, max_length=max_length), - "sentences": TextSentencesDataSource(backbone=backbone, max_length=max_length), + DefaultDataSources.CSV: TextCSVDataSource(self.tokenizer, max_length=max_length), + "sentences": TextSentencesDataSource(self.tokenizer, max_length=max_length), } ) diff --git a/flash/text/seq2seq/core/data.py b/flash/text/seq2seq/core/data.py index 883fd9c8bc..410f3c58d4 100644 --- a/flash/text/seq2seq/core/data.py +++ b/flash/text/seq2seq/core/data.py @@ -18,84 +18,73 @@ import datasets import torch from datasets import DatasetDict, load_dataset -from pytorch_lightning.utilities.exceptions import MisconfigurationException from torch import Tensor from transformers import AutoTokenizer, default_data_collator from flash.data.data_module import DataModule -from flash.data.process import Postprocess, Preprocess +from flash.data.data_source import DataSource, DefaultDataSources +from flash.data.process import Preprocess -class Seq2SeqPreprocess(Preprocess): +class Seq2SeqDataSource(DataSource): def __init__( self, tokenizer, - input: str, - filetype: str, - target: Optional[str] = None, max_source_length: int = 128, max_target_length: int = 128, - padding: Union[str, bool] = 'longest' + padding: Union[str, bool] = 'max_length' ): super().__init__() self.tokenizer = tokenizer - self.input = input - self.filetype = filetype - self.target = target - self.max_target_length = max_target_length self.max_source_length = max_source_length + self.max_target_length = max_target_length self.padding = padding - self._tokenize_fn = partial( - self._tokenize_fn, - tokenizer=self.tokenizer, - input=self.input, - target=self.target, - max_source_length=self.max_source_length, + + def _tokenize_fn( + self, + ex: Union[Dict[str, str], str], + input: Optional[str] = None, + target: Optional[str] = None, + ) -> Callable: + if isinstance(ex, dict): + ex_input = ex[input] + ex_target = ex[target] if target else None + else: + ex_input = ex + ex_target = None + + return self.tokenizer.prepare_seq2seq_batch( + src_texts=ex_input, + tgt_texts=ex_target, + max_length=self.max_source_length, max_target_length=self.max_target_length, - padding=self.padding + padding=self.padding, ) - def get_state_dict(self) -> Dict[str, Any]: - return { - "input": self.input, - "backbone": self.backbone, - "max_length": self.max_length, - "target": self.target, - "filetype": self.filetype, - "label_to_class_mapping": self.label_to_class_mapping, - } - @classmethod - def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool): - return cls(**state_dict) +class 
Seq2SeqFileDataSource(Seq2SeqDataSource): - @staticmethod - def _tokenize_fn( - ex, + def __init__( + self, + filetype: str, tokenizer, - input: str, - target: Optional[str], - max_source_length: int, - max_target_length: int, - padding: Union[str, bool], - ) -> Callable: - output = tokenizer.prepare_seq2seq_batch( - src_texts=ex[input], - tgt_texts=ex[target] if target else None, - max_length=max_source_length, - max_target_length=max_target_length, - padding=padding, - ) - return output + max_source_length: int = 128, + max_target_length: int = 128, + padding: Union[str, bool] = 'max_length', + ): + super().__init__(tokenizer, max_source_length, max_target_length, padding) + + self.filetype = filetype def load_data( self, - file: str, - use_full: bool = True, + data: Any, + use_full: bool = False, columns: List[str] = ["input_ids", "attention_mask", "labels"] ) -> 'datasets.Dataset': + file, input, target = data data_files = {} stage = self._running_stage.value data_files[stage] = str(file) @@ -112,140 +101,122 @@ def load_data( except AssertionError: dataset_dict = load_dataset(self.filetype, data_files=data_files) - dataset_dict = dataset_dict.map(self._tokenize_fn, batched=True) + dataset_dict = dataset_dict.map(partial(self._tokenize_fn, input=input, target=target), batched=True) dataset_dict.set_format(columns=columns) return dataset_dict[stage] - def predict_load_data(self, sample: Any) -> Union['datasets.Dataset', List[Dict[str, torch.Tensor]]]: - if isinstance(sample, str) and os.path.isfile(sample) and sample.endswith(".csv"): - return self.load_data(sample, use_full=True, columns=["input_ids", "attention_mask"]) - else: - if isinstance(sample, (list, tuple)) and len(sample) > 0 and all(isinstance(s, str) for s in sample): - return [self._tokenize_fn({self.input: s, self.target: None}) for s in sample] - else: - raise MisconfigurationException("Currently, we support only list of sentences") + def predict_load_data(self, data: Any) -> Union['datasets.Dataset', List[Dict[str, torch.Tensor]]]: + return self.load_data(data, use_full=False, columns=["input_ids", "attention_mask"]) - def collate(self, samples: Any) -> Tensor: - """Override to convert a set of samples to a batch""" - return default_data_collator(samples) +class Seq2SeqCSVDataSource(Seq2SeqFileDataSource): -class Seq2SeqData(DataModule): - """Data module for Seq2Seq tasks.""" - - preprocess_cls = Seq2SeqPreprocess - - @classmethod - def from_files( - cls, - train_file: Optional[str], - input: str = 'input', - target: Optional[str] = None, - filetype: str = "csv", - backbone: str = "sshleifer/tiny-mbart", - val_file: Optional[str] = None, - test_file: Optional[str] = None, - predict_file: Optional[str] = None, + def __init__( + self, + tokenizer, max_source_length: int = 128, max_target_length: int = 128, padding: Union[str, bool] = 'max_length', - batch_size: int = 32, - num_workers: Optional[int] = None, - preprocess: Optional[Preprocess] = None, - postprocess: Optional[Postprocess] = None, ): - """Creates a Seq2SeqData object from files. - Args: - train_file: Path to training data. - input: The field storing the source translation text. - target: The field storing the target translation text. - filetype: ``csv`` or ``json`` File - backbone: Tokenizer backbone to use, can use any HuggingFace tokenizer. - val_file: Path to validation data. - test_file: Path to test data. - max_source_length: Maximum length of the source text. Any text longer will be truncated. - max_target_length: Maximum length of the target text. 
Any text longer will be truncated. - padding: Padding strategy for batches. Default is pad to maximum length. - batch_size: The batchsize to use for parallel loading. Defaults to 32. - num_workers: The number of workers to use for parallelized loading. - Defaults to None which equals the number of available CPU threads, - or 0 for Darwin platform. - Returns: - Seq2SeqData: The constructed data module. - Examples:: - train_df = pd.read_csv("train_data.csv") - tab_data = TabularData.from_df(train_df, - target="fraud", - num_cols=["account_value"], - cat_cols=["account_type"]) - """ - tokenizer = AutoTokenizer.from_pretrained(backbone, use_fast=True) - preprocess = preprocess or cls.preprocess_cls( + super().__init__( + "csv", tokenizer, - input, - filetype, - target, - max_source_length, - max_target_length, - padding, + max_source_length=max_source_length, + max_target_length=max_target_length, + padding=padding, ) - return cls.from_load_data_inputs( - train_load_data_input=train_file, - val_load_data_input=val_file, - test_load_data_input=test_file, - predict_load_data_input=predict_file, - batch_size=batch_size, - num_workers=num_workers, - preprocess=preprocess, - postprocess=postprocess, - ) - @classmethod - def from_file( - cls, - predict_file: str, - input: str = 'input', - target: Optional[str] = None, - backbone: str = "sshleifer/tiny-mbart", - filetype: str = "csv", +class Seq2SeqJSONDataSource(Seq2SeqFileDataSource): + + def __init__( + self, + tokenizer, max_source_length: int = 128, max_target_length: int = 128, padding: Union[str, bool] = 'max_length', - batch_size: int = 32, - num_workers: Optional[int] = None, - preprocess: Optional[Preprocess] = None, - postprocess: Optional[Postprocess] = None, ): - """Creates a TextClassificationData object from files. - Args: - predict_file: Path to prediction input file. - input: The field storing the source translation text. - target: The field storing the target translation text. - backbone: Tokenizer backbone to use, can use any HuggingFace tokenizer. - filetype: Csv or json. - max_source_length: Maximum length of the source text. Any text longer will be truncated. - max_target_length: Maximum length of the target text. Any text longer will be truncated. - padding: Padding strategy for batches. Default is pad to maximum length. - batch_size: The batchsize to use for parallel loading. Defaults to 32. - num_workers: The number of workers to use for parallelized loading. - Defaults to None which equals the number of available CPU threads, - or 0 for Darwin platform. - Returns: - Seq2SeqData: The constructed data module. 
- """ - return cls.from_files( - train_file=None, - input=input, - target=target, - filetype=filetype, - backbone=backbone, - predict_file=predict_file, + super().__init__( + "json", + tokenizer, max_source_length=max_source_length, max_target_length=max_target_length, padding=padding, - batch_size=batch_size, - num_workers=num_workers, - preprocess=preprocess, - postprocess=postprocess, ) + + +class Seq2SeqSentencesDataSource(Seq2SeqDataSource): + + def load_data( + self, + data: Union[str, List[str]], + dataset: Optional[Any] = None, + ) -> List[Any]: + + if isinstance(data, str): + data = [data] + return [self._tokenize_fn(s) for s in data] + + +class Seq2SeqPreprocess(Preprocess): + + def __init__( + self, + train_transform: Optional[Dict[str, Callable]] = None, + val_transform: Optional[Dict[str, Callable]] = None, + test_transform: Optional[Dict[str, Callable]] = None, + predict_transform: Optional[Dict[str, Callable]] = None, + backbone: str = "sshleifer/tiny-mbart", + max_source_length: int = 128, + max_target_length: int = 128, + padding: Union[str, bool] = 'max_length' + ): + self.backbone = backbone + self.max_target_length = max_target_length + self.max_source_length = max_source_length + self.padding = padding + + self.tokenizer = AutoTokenizer.from_pretrained(backbone, use_fast=True) + + super().__init__( + train_transform=train_transform, + val_transform=val_transform, + test_transform=test_transform, + predict_transform=predict_transform, + data_sources={ + DefaultDataSources.CSV: Seq2SeqCSVDataSource( + self.tokenizer, + max_source_length=max_source_length, + max_target_length=max_target_length, + padding=padding, + ), + "sentences": Seq2SeqSentencesDataSource( + self.tokenizer, + max_source_length=max_source_length, + max_target_length=max_target_length, + padding=padding, + ), + }, + ) + + def get_state_dict(self) -> Dict[str, Any]: + return { + "backbone": self.backbone, + "max_source_length": self.max_source_length, + "max_target_length": self.max_target_length, + "padding": self.padding, + } + + @classmethod + def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool): + return cls(**state_dict) + + def collate(self, samples: Any) -> Tensor: + """Override to convert a set of samples to a batch""" + return default_data_collator(samples) + + +class Seq2SeqData(DataModule): + """Data module for Seq2Seq tasks.""" + + preprocess_cls = Seq2SeqPreprocess diff --git a/flash/text/seq2seq/summarization/data.py b/flash/text/seq2seq/summarization/data.py index bcdd2a2ff6..791c98a32f 100644 --- a/flash/text/seq2seq/summarization/data.py +++ b/flash/text/seq2seq/summarization/data.py @@ -11,11 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Any, Optional, Type, Union +from typing import Any from transformers import AutoTokenizer -from flash.data.process import Postprocess, Preprocess +from flash.data.process import Postprocess from flash.text.seq2seq.core.data import Seq2SeqData, Seq2SeqPreprocess @@ -23,10 +23,12 @@ class SummarizationPostprocess(Postprocess): def __init__( self, - tokenizer: AutoTokenizer, + backbone: str = "sshleifer/tiny-mbart", ): super().__init__() - self.tokenizer = tokenizer + + # TODO: Should share the backbone or tokenizer over state + self.tokenizer = AutoTokenizer.from_pretrained(backbone, use_fast=True) def uncollate(self, generated_tokens: Any) -> Any: pred_str = self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True) @@ -38,127 +40,3 @@ class SummarizationData(Seq2SeqData): preprocess_cls = Seq2SeqPreprocess postprocess_cls = SummarizationPostprocess - - @classmethod - def from_files( - cls, - train_file: Optional[str] = None, - input: str = 'input', - target: Optional[str] = None, - filetype: str = "csv", - backbone: str = "t5-small", - val_file: str = None, - test_file: str = None, - predict_file: str = None, - max_source_length: int = 512, - max_target_length: int = 128, - padding: Union[str, bool] = 'max_length', - batch_size: int = 16, - num_workers: Optional[int] = None, - preprocess: Optional[Preprocess] = None, - postprocess: Optional[Postprocess] = None, - ): - """Creates a SummarizationData object from files. - - Args: - train_file: Path to training data. - input: The field storing the source translation text. - target: The field storing the target translation text. - filetype: .csv or .json - backbone: Tokenizer backbone to use, can use any HuggingFace tokenizer. - val_file: Path to validation data. - test_file: Path to test data. - max_source_length: Maximum length of the source text. Any text longer will be truncated. - max_target_length: Maximum length of the target text. Any text longer will be truncated. - padding: Padding strategy for batches. Default is pad to maximum length. - batch_size: The batchsize to use for parallel loading. Defaults to 16. - num_workers: The number of workers to use for parallelized loading. - Defaults to None which equals the number of available CPU threads, - or 0 for Darwin platform. - - Returns: - SummarizationData: The constructed data module. 
- - Examples:: - - train_df = pd.read_csv("train_data.csv") - tab_data = TabularData.from_df(train_df, target="fraud", - num_cols=["account_value"], - cat_cols=["account_type"]) - - """ - tokenizer = AutoTokenizer.from_pretrained(backbone, use_fast=True) - - preprocess = preprocess or cls.preprocess_cls( - tokenizer, - input, - filetype, - target, - max_source_length, - max_target_length, - padding, - ) - - postprocess = postprocess or cls.postprocess_cls(tokenizer) - - return cls.from_load_data_inputs( - train_load_data_input=train_file, - val_load_data_input=val_file, - test_load_data_input=test_file, - predict_load_data_input=predict_file, - batch_size=batch_size, - num_workers=num_workers, - preprocess=preprocess, - postprocess=postprocess, - ) - - @classmethod - def from_file( - cls, - predict_file: str, - input: str = 'src_text', - target: Optional[str] = None, - backbone: str = "t5-small", - filetype: str = "csv", - max_source_length: int = 512, - max_target_length: int = 128, - padding: Union[str, bool] = 'longest', - batch_size: int = 16, - num_workers: Optional[int] = None, - preprocess: Optional[Preprocess] = None, - postprocess: Optional[Postprocess] = None, - ): - """Creates a SummarizationData object from files. - - Args: - predict_file: Path to prediction input file. - input: The field storing the source translation text. - target: The field storing the target translation text. - backbone: Tokenizer backbone to use, can use any HuggingFace tokenizer. - filetype: csv or json. - max_source_length: Maximum length of the source text. Any text longer will be truncated. - max_target_length: Maximum length of the target text. Any text longer will be truncated. - padding: Padding strategy for batches. Default is pad to maximum length. - batch_size: The batchsize to use for parallel loading. Defaults to 16. - num_workers: The number of workers to use for parallelized loading. - Defaults to None which equals the number of available CPU threads, - or 0 for Darwin platform. - - Returns: - SummarizationData: The constructed data module. 
- - """ - return super().from_file( - predict_file=predict_file, - input=input, - target=target, - backbone=backbone, - filetype=filetype, - max_source_length=max_source_length, - max_target_length=max_target_length, - padding=padding, - batch_size=batch_size, - num_workers=num_workers, - preprocess=preprocess, - postprocess=postprocess, - ) diff --git a/flash/text/seq2seq/summarization/metric.py b/flash/text/seq2seq/summarization/metric.py index c8cbff6d14..694f0d5763 100644 --- a/flash/text/seq2seq/summarization/metric.py +++ b/flash/text/seq2seq/summarization/metric.py @@ -19,7 +19,7 @@ from torch import tensor from torchmetrics import Metric -from flash.text.seq2seq import summarization +from flash.text.seq2seq.summarization.utils import add_newline_to_end_of_each_sentence class RougeMetric(Metric): @@ -67,8 +67,8 @@ def update(self, pred_lns: List[str], tgt_lns: List[str]): for pred, tgt in zip(pred_lns, tgt_lns): # rougeLsum expects "\n" separated sentences within a summary if self.rouge_newline_sep: - pred = summarization.utils.add_newline_to_end_of_each_sentence(pred) - tgt = summarization.utils.add_newline_to_end_of_each_sentence(tgt) + pred = add_newline_to_end_of_each_sentence(pred) + tgt = add_newline_to_end_of_each_sentence(tgt) results = self.scorer.score(pred, tgt) for key, score in results.items(): score = tensor([score.precision, score.recall, score.fmeasure]) diff --git a/flash/text/seq2seq/summarization/model.py b/flash/text/seq2seq/summarization/model.py index a3c9142bb5..04e763780b 100644 --- a/flash/text/seq2seq/summarization/model.py +++ b/flash/text/seq2seq/summarization/model.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Type, Union +from typing import Callable, Dict, Mapping, Optional, Sequence, Type, Union import pytorch_lightning as pl import torch @@ -37,7 +37,7 @@ class SummarizationTask(Seq2SeqTask): def __init__( self, - backbone: str = "t5-small", + backbone: str = "sshleifer/tiny-mbart", loss_fn: Optional[Union[Callable, Mapping, Sequence]] = None, optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam, metrics: Union[pl.metrics.Metric, Mapping, Sequence, None] = None, diff --git a/flash_examples/finetuning/summarization.py b/flash_examples/finetuning/summarization.py index d2ecc726f3..78757ebacc 100644 --- a/flash_examples/finetuning/summarization.py +++ b/flash_examples/finetuning/summarization.py @@ -21,12 +21,12 @@ download_data("https://pl-flash-data.s3.amazonaws.com/xsum.zip", "data/") # 2. Load the data -datamodule = SummarizationData.from_files( +datamodule = SummarizationData.from_csv( + "input", + "target", train_file="data/xsum/train.csv", val_file="data/xsum/valid.csv", test_file="data/xsum/test.csv", - input="input", - target="target" ) # 3. Build the model diff --git a/flash_examples/predict/summarization.py b/flash_examples/predict/summarization.py index 6d16ebfcaf..eb3358b1f4 100644 --- a/flash_examples/predict/summarization.py +++ b/flash_examples/predict/summarization.py @@ -20,37 +20,40 @@ download_data("https://pl-flash-data.s3.amazonaws.com/xsum.zip", "data/") # 2. Load the model from a checkpoint -model = SummarizationTask.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/summarization_model_xsum.pt") +model = SummarizationTask.load_from_checkpoint("../finetuning/summarization_model_xsum.pt") # 2a. 
Summarize an article! -predictions = model.predict([ - """ - Camilla bought a box of mangoes with a Brixton £10 note, introduced last year to try to keep the money of local - people within the community.The couple were surrounded by shoppers as they walked along Electric Avenue. - They came to Brixton to see work which has started to revitalise the borough. - It was Charles' first visit to the area since 1996, when he was accompanied by the former - South African president Nelson Mandela.Greengrocer Derek Chong, who has run a stall on Electric Avenue - for 20 years, said Camilla had been ""nice and pleasant"" when she purchased the fruit. - ""She asked me what was nice, what would I recommend, and I said we've got some nice mangoes. - She asked me were they ripe and I said yes - they're from the Dominican Republic."" - Mr Chong is one of 170 local retailers who accept the Brixton Pound. - Customers exchange traditional pound coins for Brixton Pounds and then spend them at the market - or in participating shops. - During the visit, Prince Charles spent time talking to youth worker Marcus West, who works with children - nearby on an estate off Coldharbour Lane. Mr West said: - ""He's on the level, really down-to-earth. They were very cheery. The prince is a lovely man."" - He added: ""I told him I was working with young kids and he said, 'Keep up all the good work.'"" - Prince Charles also visited the Railway Hotel, at the invitation of his charity The Prince's Regeneration Trust. - The trust hopes to restore and refurbish the building, - where once Jimi Hendrix and The Clash played, as a new community and business centre." - """ -]) +predictions = model.predict( + [ + """ + Camilla bought a box of mangoes with a Brixton £10 note, introduced last year to try to keep the money of local + people within the community.The couple were surrounded by shoppers as they walked along Electric Avenue. + They came to Brixton to see work which has started to revitalise the borough. + It was Charles' first visit to the area since 1996, when he was accompanied by the former + South African president Nelson Mandela.Greengrocer Derek Chong, who has run a stall on Electric Avenue + for 20 years, said Camilla had been ""nice and pleasant"" when she purchased the fruit. + ""She asked me what was nice, what would I recommend, and I said we've got some nice mangoes. + She asked me were they ripe and I said yes - they're from the Dominican Republic."" + Mr Chong is one of 170 local retailers who accept the Brixton Pound. + Customers exchange traditional pound coins for Brixton Pounds and then spend them at the market + or in participating shops. + During the visit, Prince Charles spent time talking to youth worker Marcus West, who works with children + nearby on an estate off Coldharbour Lane. Mr West said: + ""He's on the level, really down-to-earth. They were very cheery. The prince is a lovely man."" + He added: ""I told him I was working with young kids and he said, 'Keep up all the good work.'"" + Prince Charles also visited the Railway Hotel, at the invitation of his charity The Prince's Regeneration Trust. + The trust hopes to restore and refurbish the building, + where once Jimi Hendrix and The Clash played, as a new community and business centre." + """ + ], + data_source="sentences", +) print(predictions) # 2b. Or generate summaries from a sheet file! 
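# [Illustrative sketch, not part of the original patch: ``SummarizationData.from_csv``
#  below takes the source-text column name ("input") as its first argument, so the
#  predict file is expected to provide that column. A minimal stand-in file could be
#  written with pandas — the file name and article text are placeholders:]
#
#     import pandas as pd
#     pd.DataFrame({"input": ["A long article to summarise ..."]}).to_csv(
#         "data/xsum/predict.csv", index=False
#     )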
-datamodule = SummarizationData.from_files( +datamodule = SummarizationData.from_csv( + "input", predict_file="data/xsum/predict.csv", - input="input", ) predictions = Trainer().predict(model, datamodule=datamodule) print(predictions) From d9c00c5760359ad674965e32ff4210ffe8a51b63 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Thu, 6 May 2021 10:13:20 +0100 Subject: [PATCH 26/78] Add translation --- flash/text/seq2seq/translation/data.py | 130 ++++------------------- flash_examples/finetuning/translation.py | 8 +- flash_examples/predict/translation.py | 17 +-- 3 files changed, 33 insertions(+), 122 deletions(-) diff --git a/flash/text/seq2seq/translation/data.py b/flash/text/seq2seq/translation/data.py index 940bae7af8..32a91746b8 100644 --- a/flash/text/seq2seq/translation/data.py +++ b/flash/text/seq2seq/translation/data.py @@ -11,129 +11,37 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Union +from typing import Callable, Dict, Optional, Union -from flash.data.process import Postprocess, Preprocess -from flash.text.seq2seq.core.data import Seq2SeqData +from flash.text.seq2seq.core.data import Seq2SeqData, Seq2SeqPreprocess -class TranslationData(Seq2SeqData): - """Data module for Translation tasks.""" +class TranslationPreprocess(Seq2SeqPreprocess): - @classmethod - def from_files( - cls, - train_file, - input: str = 'input', - target: Optional[str] = None, - filetype="csv", - backbone="facebook/mbart-large-en-ro", - val_file=None, - test_file=None, - predict_file=None, + def __init__( + self, + train_transform: Optional[Dict[str, Callable]] = None, + val_transform: Optional[Dict[str, Callable]] = None, + test_transform: Optional[Dict[str, Callable]] = None, + predict_transform: Optional[Dict[str, Callable]] = None, + backbone: str = "facebook/mbart-large-en-ro", max_source_length: int = 128, max_target_length: int = 128, - padding: Union[str, bool] = 'max_length', - batch_size: int = 8, - num_workers: Optional[int] = None, - preprocess: Optional[Preprocess] = None, - postprocess: Optional[Postprocess] = None, + padding: Union[str, bool] = 'max_length' ): - """Creates a TranslateData object from files. - - Args: - train_file: Path to training data. - input: The field storing the source translation text. - target: The field storing the target translation text. - filetype: .csv or .json - backbone: Tokenizer backbone to use, can use any HuggingFace tokenizer. - val_file: Path to validation data. - test_file: Path to test data. - predict_file: Path to predict data. - max_source_length: Maximum length of the source text. Any text longer will be truncated. - max_target_length: Maximum length of the target text. Any text longer will be truncated. - padding: Padding strategy for batches. Default is pad to maximum length. - batch_size: The batchsize to use for parallel loading. Defaults to 8. - num_workers: The number of workers to use for parallelized loading. - Defaults to None which equals the number of available CPU threads, - or 0 for Darwin platform. - - Returns: - TranslateData: The constructed data module. 
- - Examples:: - - train_df = pd.read_csv("train_data.csv") - tab_data = TabularData.from_df(train_df, target="fraud", - num_cols=["account_value"], - cat_cols=["account_type"]) - - """ - return super().from_files( - train_file=train_file, - val_file=val_file, - test_file=test_file, - predict_file=predict_file, - input=input, - target=target, + super().__init__( + train_transform=train_transform, + val_transform=val_transform, + test_transform=test_transform, + predict_transform=predict_transform, backbone=backbone, - filetype=filetype, max_source_length=max_source_length, max_target_length=max_target_length, padding=padding, - batch_size=batch_size, - num_workers=num_workers, - preprocess=preprocess, - postprocess=postprocess, ) - @classmethod - def from_file( - cls, - predict_file: str, - input: str = 'input', - target: Optional[str] = None, - backbone="facebook/mbart-large-en-ro", - filetype="csv", - max_source_length: int = 128, - max_target_length: int = 128, - padding: Union[str, bool] = 'longest', - batch_size: int = 8, - num_workers: Optional[int] = None, - preprocess: Optional[Preprocess] = None, - postprocess: Optional[Postprocess] = None, - ): - """Creates a TranslationData object from files. - Args: - predict_file: Path to prediction input file. - input: The field storing the source translation text. - target: The field storing the target translation text. - backbone: Tokenizer backbone to use, can use any HuggingFace tokenizer. - filetype: csv or json. - max_source_length: Maximum length of the source text. Any text longer will be truncated. - max_target_length: Maximum length of the target text. Any text longer will be truncated. - padding: Padding strategy for batches. Default is pad to maximum length. - batch_size: The batchsize to use for parallel loading. Defaults to 8. - num_workers: The number of workers to use for parallelized loading. - Defaults to None which equals the number of available CPU threads, - - - Returns: - Seq2SeqData: The constructed data module. +class TranslationData(Seq2SeqData): + """Data module for Translation tasks.""" - """ - return super().from_file( - predict_file=predict_file, - input=input, - target=target, - backbone=backbone, - filetype=filetype, - max_source_length=max_source_length, - max_target_length=max_target_length, - padding=padding, - batch_size=batch_size, - num_workers=num_workers, - preprocess=preprocess, - postprocess=postprocess, - ) + preprocess_cls = TranslationPreprocess diff --git a/flash_examples/finetuning/translation.py b/flash_examples/finetuning/translation.py index be91ea057d..69440bed66 100644 --- a/flash_examples/finetuning/translation.py +++ b/flash_examples/finetuning/translation.py @@ -21,13 +21,13 @@ download_data("https://pl-flash-data.s3.amazonaws.com/wmt_en_ro.zip", "data/") # 2. Load the data -datamodule = TranslationData.from_files( +datamodule = TranslationData.from_csv( + "input", + "target", train_file="data/wmt_en_ro/train.csv", val_file="data/wmt_en_ro/valid.csv", test_file="data/wmt_en_ro/test.csv", - input="input", - target="target", - batch_size=1 + batch_size=1, ) # 3. Build the model diff --git a/flash_examples/predict/translation.py b/flash_examples/predict/translation.py index bbf3d42446..529ccf8ea1 100644 --- a/flash_examples/predict/translation.py +++ b/flash_examples/predict/translation.py @@ -20,19 +20,22 @@ download_data("https://pl-flash-data.s3.amazonaws.com/wmt_en_ro.zip", "data/") # 2. 
Load the model from a checkpoint -model = TranslationTask.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/translation_model_en_ro.pt") +model = TranslationTask.load_from_checkpoint("../finetuning/translation_model_en_ro.pt") # 2a. Translate a few sentences! -predictions = model.predict([ - "BBC News went to meet one of the project's first graduates.", - "A recession has come as quickly as 11 months after the first rate hike and as long as 86 months.", -]) +predictions = model.predict( + [ + "BBC News went to meet one of the project's first graduates.", + "A recession has come as quickly as 11 months after the first rate hike and as long as 86 months.", + ], + data_source="sentences", +) print(predictions) # 2b. Or generate translations from a sheet file! -datamodule = TranslationData.from_file( +datamodule = TranslationData.from_csv( + "input", predict_file="data/wmt_en_ro/predict.csv", - input="input", ) predictions = Trainer().predict(model, datamodule=datamodule) print(predictions) From d5b8c4a230b11b42ec7af503240f5f80344afe5d Mon Sep 17 00:00:00 2001 From: Edgar Riba Date: Thu, 6 May 2021 12:18:35 +0200 Subject: [PATCH 27/78] assert empty data_source in datapipeline creation --- flash/core/model.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/flash/core/model.py b/flash/core/model.py index 3a2bccd345..7d063a870a 100644 --- a/flash/core/model.py +++ b/flash/core/model.py @@ -293,6 +293,10 @@ def build_data_pipeline( preprocess = getattr(self.trainer.datamodule.data_pipeline, '_preprocess_pipeline', None) postprocess = getattr(self.trainer.datamodule.data_pipeline, '_postprocess_pipeline', None) serializer = getattr(self.trainer.datamodule.data_pipeline, '_serializer', None) + else: + # TODO: we should log with low severity level that we use defaults to create + # `preprocess`, `postprocess` and `serializer`. + pass # Defaults / task attributes preprocess, postprocess, serializer = Task._resolve( @@ -318,6 +322,8 @@ def build_data_pipeline( data_source = data_source or old_data_source if isinstance(data_source, str): + assert preprocess is not None, type(preprocess) + # TODO: somehow the preprocess is not well generated when is a Default type data_source = preprocess.data_source_of_name(data_source) data_pipeline = DataPipeline(data_source, preprocess, postprocess, serializer) From dd35da62945015ec858bdd058efcd51be10bb026 Mon Sep 17 00:00:00 2001 From: Edgar Riba Date: Thu, 6 May 2021 13:32:59 +0200 Subject: [PATCH 28/78] add more assertions for test_classification_task_predict_folder_path --- flash/core/model.py | 2 ++ flash/data/data_source.py | 4 ++++ 2 files changed, 6 insertions(+) diff --git a/flash/core/model.py b/flash/core/model.py index 7d063a870a..485ee191d6 100644 --- a/flash/core/model.py +++ b/flash/core/model.py @@ -104,6 +104,7 @@ def __init__( self._postprocess: Optional[Postprocess] = postprocess self._serializer: Optional[Serializer] = None + # TODO: create enum values to define what are the exact states self._data_pipeline_state: Optional[DataPipelineState] = None # Explicitly set the serializer to call the setter @@ -176,6 +177,7 @@ def predict( data_pipeline = self.build_data_pipeline(data_source, data_pipeline) x = [x for x in data_pipeline._data_source.generate_dataset(x, running_stage)] + assert len(x) > 0, "List of inputs shouldn't be empty." 
x = data_pipeline.worker_preprocessor(running_stage)(x) # switch to self.device when #7188 merge in Lightning x = self.transfer_batch_to_device(x, next(self.parameters()).device) diff --git a/flash/data/data_source.py b/flash/data/data_source.py index fac7f77856..506f54a153 100644 --- a/flash/data/data_source.py +++ b/flash/data/data_source.py @@ -97,6 +97,9 @@ def generate_dataset( running_stage: RunningStage, ) -> Optional[Union[AutoDataset, IterableAutoDataset]]: is_none = data is None + # TODO: we should parse better the possible data types here. + # Are `pata_paths` considered as Sequence ? for now it pass + # the statement found in below. if isinstance(data, Sequence): is_none = data[0] is None @@ -177,6 +180,7 @@ def load_data(self, data: str, dataset: Optional[Any] = None) -> Iterable[Mappin classes, class_to_idx = self.find_classes(data) if not classes: files = [os.path.join(data, file) for file in os.listdir(data)] + assert len(files) > 0, "Files list shouldn't be empty." return [{ DefaultDataKeys.INPUT: file } for file in filter( From f2c3f20f8ca7c4a89e9fa2a5aa21db3c8f264101 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Thu, 6 May 2021 12:40:24 +0100 Subject: [PATCH 29/78] Add video --- flash/core/model.py | 2 +- flash/data/data_module.py | 4 +- flash/data/data_source.py | 94 +++---- flash/video/classification/data.py | 255 ++++++------------ flash/vision/classification/data.py | 5 +- flash/vision/data.py | 36 +-- .../finetuning/video_classification.py | 23 +- .../predict/video_classification.py | 19 +- 8 files changed, 159 insertions(+), 279 deletions(-) diff --git a/flash/core/model.py b/flash/core/model.py index 3a2bccd345..9642c09953 100644 --- a/flash/core/model.py +++ b/flash/core/model.py @@ -157,7 +157,7 @@ def test_step(self, batch: Any, batch_idx: int) -> None: def predict( self, x: Any, - data_source: Union[str, DataSource] = DefaultDataSources.FILES, + data_source: Union[str, DataSource] = DefaultDataSources.PATHS, data_pipeline: Optional[DataPipeline] = None, ) -> Any: """ diff --git a/flash/data/data_module.py b/flash/data/data_module.py index e4b435a639..9fce1d4ed1 100644 --- a/flash/data/data_module.py +++ b/flash/data/data_module.py @@ -401,7 +401,7 @@ def from_folders( **preprocess_kwargs: Any, ) -> 'DataModule': return cls.from_data_source( - DefaultDataSources.FOLDERS, + DefaultDataSources.PATHS, train_folder, val_folder, test_folder, @@ -440,7 +440,7 @@ def from_files( **preprocess_kwargs: Any, ) -> 'DataModule': return cls.from_data_source( - DefaultDataSources.FILES, + DefaultDataSources.PATHS, (train_files, train_targets), (val_files, val_targets), (test_files, test_targets), diff --git a/flash/data/data_source.py b/flash/data/data_source.py index fac7f77856..13066655f5 100644 --- a/flash/data/data_source.py +++ b/flash/data/data_source.py @@ -129,8 +129,7 @@ def generate_dataset( class DefaultDataSources(LightningEnum): - FOLDERS = "folders" - FILES = "files" + PATHS = "paths" NUMPY = "numpy" TENSOR = "tensor" CSV = "csv" @@ -150,46 +149,6 @@ def __hash__(self) -> int: return hash(self.value) -class FoldersDataSource(DataSource[str]): - - def __init__(self, extensions: Optional[Tuple[str, ...]] = None): - super().__init__() - - self.extensions = extensions - - @staticmethod - def find_classes(dir: str) -> Tuple[List[str], Dict[str, int]]: - """ - Finds the class folders in a dataset. Ensures that no class is a subdirectory of another. - - Args: - dir: Root directory path. 
- - Returns: - tuple: (classes, class_to_idx) where classes are relative to (dir), and class_to_idx is a dictionary. - """ - classes = [d.name for d in os.scandir(dir) if d.is_dir()] - classes.sort() - class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)} - return classes, class_to_idx - - def load_data(self, data: str, dataset: Optional[Any] = None) -> Iterable[Mapping[str, Any]]: - classes, class_to_idx = self.find_classes(data) - if not classes: - files = [os.path.join(data, file) for file in os.listdir(data)] - return [{ - DefaultDataKeys.INPUT: file - } for file in filter( - lambda file: has_file_allowed_extension(file, self.extensions), - files, - )] - else: - self.set_state(LabelsState(classes)) - dataset.num_classes = len(classes) - data = make_dataset(data, class_to_idx, extensions=self.extensions) - return [{DefaultDataKeys.INPUT: input, DefaultDataKeys.TARGET: target} for input, target in data] - - SEQUENCE_DATA_TYPE = TypeVar("SEQUENCE_DATA_TYPE") @@ -224,18 +183,44 @@ def predict_load_data(self, data: Sequence[SEQUENCE_DATA_TYPE]) -> Sequence[Mapp return [{DefaultDataKeys.INPUT: input} for input in data] -class FilesDataSource(SequenceDataSource[str]): +class PathsDataSource(SequenceDataSource): # TODO: Sort out the typing here - def __init__(self, extensions: Optional[Tuple[str, ...]] = None, labels: Optional[Sequence[str]] = None): - super().__init__(labels=labels) + def __init__(self, extensions: Optional[Tuple[str, ...]] = None): + super().__init__() self.extensions = extensions - def load_data( - self, - data: Tuple[Sequence[str], Optional[Sequence[Any]]], - dataset: Optional[Any] = None, - ) -> Sequence[Mapping[str, Any]]: + @staticmethod + def find_classes(dir: str) -> Tuple[List[str], Dict[str, int]]: + """ + Finds the class folders in a dataset. Ensures that no class is a subdirectory of another. + + Args: + dir: Root directory path. + + Returns: + tuple: (classes, class_to_idx) where classes are relative to (dir), and class_to_idx is a dictionary. 
+ """ + classes = [d.name for d in os.scandir(dir) if d.is_dir()] + classes.sort() + class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)} + return classes, class_to_idx + + def load_data(self, + data: Union[str, Tuple[List[str], List[Any]]], + dataset: Optional[Any] = None) -> Iterable[Mapping[str, Any]]: + if isinstance(data, str) and os.path.isdir(data): + classes, class_to_idx = self.find_classes(data) + if not classes: + return self.predict_load_data(data) + else: + self.set_state(LabelsState(classes)) + + if dataset is not None: + dataset.num_classes = len(classes) + + data = make_dataset(data, class_to_idx, extensions=self.extensions) + return [{DefaultDataKeys.INPUT: input, DefaultDataKeys.TARGET: target} for input, target in data] return list( filter( lambda sample: has_file_allowed_extension(sample[DefaultDataKeys.INPUT], self.extensions), @@ -243,7 +228,14 @@ def load_data( ) ) - def predict_load_data(self, data: Sequence[str]) -> Sequence[Mapping[str, Any]]: + def predict_load_data(self, + data: Union[str, List[str]], + dataset: Optional[Any] = None) -> Iterable[Mapping[str, Any]]: + if isinstance(data, str): + if os.path.isdir(data): + data = [os.path.join(data, file) for file in os.listdir(data)] + else: + data = [data] return list( filter( lambda sample: has_file_allowed_extension(sample[DefaultDataKeys.INPUT], self.extensions), diff --git a/flash/video/classification/data.py b/flash/video/classification/data.py index 3bac7e92ed..02ec133308 100644 --- a/flash/video/classification/data.py +++ b/flash/video/classification/data.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import os import pathlib from typing import Any, Callable, Dict, List, Optional, Type, Union @@ -19,18 +18,15 @@ import torch from pytorch_lightning.utilities.exceptions import MisconfigurationException from torch.utils.data import RandomSampler, Sampler -from torch.utils.data.dataset import IterableDataset -from flash.core.classification import ClassificationState from flash.data.data_module import DataModule +from flash.data.data_source import DefaultDataKeys, DefaultDataSources, LabelsState, PathsDataSource from flash.data.process import Preprocess from flash.utils.imports import _KORNIA_AVAILABLE, _PYTORCHVIDEO_AVAILABLE if _KORNIA_AVAILABLE: import kornia.augmentation as K - import kornia.geometry.transform as T -else: - from torchvision import transforms as T + if _PYTORCHVIDEO_AVAILABLE: from pytorchvideo.data.clip_sampling import ClipSampler, make_clip_sampler from pytorchvideo.data.encoded_video import EncodedVideo @@ -43,75 +39,24 @@ _PYTORCHVIDEO_DATA = Dict[str, Union[str, torch.Tensor, int, float, List]] -class VideoClassificationPreprocess(Preprocess): - - EXTENSIONS = ("mp4", "avi") - - @staticmethod - def default_predict_transform() -> Dict[str, 'Compose']: - return { - "post_tensor_transform": Compose([ - ApplyTransformToKey( - key="video", - transform=Compose([ - UniformTemporalSubsample(8), - RandomShortSideScale(min_size=256, max_size=320), - RandomCrop(244), - RandomHorizontalFlip(p=0.5), - ]), - ), - ]), - "per_batch_transform_on_device": Compose([ - ApplyTransformToKey( - key="video", - transform=K.VideoSequential( - K.Normalize(torch.tensor([0.45, 0.45, 0.45]), torch.tensor([0.225, 0.225, 0.225])), - data_format="BCTHW", - same_on_frame=False - ) - ), - ]), - } +class VideoClassificationPathsDataSource(PathsDataSource): 
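+    # Loads labelled video clips from a folder of class sub-folders using PyTorchVideo's
+    # ``labeled_encoded_video_dataset``, sampling clips with the given ``clip_sampler`` and decoder.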
def __init__( self, clip_sampler: 'ClipSampler', - video_sampler: Type[Sampler], - decode_audio: bool, - decoder: str, - train_transform: Optional[Dict[str, Callable]] = None, - val_transform: Optional[Dict[str, Callable]] = None, - test_transform: Optional[Dict[str, Callable]] = None, - predict_transform: Optional[Dict[str, Callable]] = None, + video_sampler: Type[Sampler] = torch.utils.data.RandomSampler, + decode_audio: bool = True, + decoder: str = "pyav", ): - # Make sure to provide your transform to the Preprocess Class - super().__init__( - train_transform, val_transform, test_transform, predict_transform or self.default_predict_transform() - ) + super().__init__(extensions=("mp4", "avi")) self.clip_sampler = clip_sampler self.video_sampler = video_sampler self.decode_audio = decode_audio self.decoder = decoder - def get_state_dict(self) -> Dict[str, Any]: - return { - 'clip_sampler': self.clip_sampler, - 'video_sampler': self.video_sampler, - 'decode_audio': self.decode_audio, - 'decoder': self.decoder, - 'train_transform': self._train_transform, - 'val_transform': self._val_transform, - 'test_transform': self._test_transform, - 'predict_transform': self._predict_transform, - } - - @classmethod - def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool) -> 'VideoClassificationPreprocess': - return cls(**state_dict) - - def load_data(self, data: Any, dataset: IterableDataset) -> 'EncodedVideoDataset': + def load_data(self, data: str, dataset: Optional[Any] = None) -> 'EncodedVideoDataset': ds: EncodedVideoDataset = labeled_encoded_video_dataset( - data, + pathlib.Path(data), self.clip_sampler, video_sampler=self.video_sampler, decode_audio=self.decode_audio, @@ -119,21 +64,10 @@ def load_data(self, data: Any, dataset: IterableDataset) -> 'EncodedVideoDataset ) if self.training: label_to_class_mapping = {p[1]: p[0].split("/")[-2] for p in ds._labeled_videos._paths_and_labels} - self.set_state(ClassificationState(label_to_class_mapping)) + self.set_state(LabelsState(label_to_class_mapping)) dataset.num_classes = len(np.unique([s[1]['label'] for s in ds._labeled_videos])) return ds - def predict_load_data(self, folder_or_file: Union[str, List[str]]) -> List[str]: - if isinstance(folder_or_file, list) and all(os.path.exists(p) for p in folder_or_file): - return folder_or_file - elif os.path.isdir(folder_or_file): - return [f for f in os.listdir(folder_or_file) if f.lower().endswith(self.EXTENSIONS)] - elif os.path.exists(folder_or_file) and folder_or_file.lower().endswith(self.EXTENSIONS): - return [folder_or_file] - raise MisconfigurationException( - f"The provided predict output should be a folder or a path. 
Found: {folder_or_file}" - ) - def _encoded_video_to_dict(self, video) -> Dict[str, Any]: ( clip_start, @@ -167,91 +101,32 @@ def _encoded_video_to_dict(self, video) -> Dict[str, Any]: } if audio_samples is not None else {}), } - def predict_load_sample(self, video_path: str) -> "EncodedVideo": - return self._encoded_video_to_dict(EncodedVideo.from_path(video_path)) - - def pre_tensor_transform(self, sample: _PYTORCHVIDEO_DATA) -> _PYTORCHVIDEO_DATA: - return self.current_transform(sample) - - def to_tensor_transform(self, sample: _PYTORCHVIDEO_DATA) -> _PYTORCHVIDEO_DATA: - return self.current_transform(sample) - - def post_tensor_transform(self, sample: _PYTORCHVIDEO_DATA) -> _PYTORCHVIDEO_DATA: - return self.current_transform(sample) - - def per_batch_transform(self, sample: _PYTORCHVIDEO_DATA) -> _PYTORCHVIDEO_DATA: - return self.current_transform(sample) + def predict_load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]: + return self._encoded_video_to_dict(EncodedVideo.from_path(sample[DefaultDataKeys.INPUT])) - def per_batch_transform_on_device(self, sample: _PYTORCHVIDEO_DATA) -> _PYTORCHVIDEO_DATA: - return self.current_transform(sample) +class VideoClassificationPreprocess(Preprocess): -class VideoClassificationData(DataModule): - """Data module for Video classification tasks.""" - - preprocess_cls = VideoClassificationPreprocess - - @classmethod - def from_paths( - cls, - train_data_path: Optional[Union[str, pathlib.Path]] = None, - val_data_path: Optional[Union[str, pathlib.Path]] = None, - test_data_path: Optional[Union[str, pathlib.Path]] = None, - predict_data_path: Union[str, pathlib.Path] = None, + def __init__( + self, + train_transform: Optional[Dict[str, Callable]] = None, + val_transform: Optional[Dict[str, Callable]] = None, + test_transform: Optional[Dict[str, Callable]] = None, + predict_transform: Optional[Dict[str, Callable]] = None, clip_sampler: Union[str, 'ClipSampler'] = "random", clip_duration: float = 2, clip_sampler_kwargs: Dict[str, Any] = None, - video_sampler: Type[Sampler] = RandomSampler, + video_sampler: Type[Sampler] = torch.utils.data.RandomSampler, decode_audio: bool = True, decoder: str = "pyav", - train_transform: Optional[Dict[str, Callable]] = None, - val_transform: Optional[Dict[str, Callable]] = None, - test_transform: Optional[Dict[str, Callable]] = None, - predict_transform: Optional[Dict[str, Callable]] = None, - batch_size: int = 4, - num_workers: Optional[int] = None, - preprocess: Optional[Preprocess] = None, - **kwargs, - ) -> 'DataModule': - """ - - Creates a VideoClassificationData object from folders of videos arranged in this way: :: - - train/class_x/xxx.ext - train/class_x/xxy.ext - train/class_x/xxz.ext - train/class_y/123.ext - train/class_y/nsdf3.ext - train/class_y/asd932_.ext - - Args: - train_data_path: Path to training folder. Default: None. - val_data_path: Path to validation folder. Default: None. - test_data_path: Path to test folder. Default: None. - predict_data_path: Path to predict folder. Default: None. - clip_sampler: ClipSampler to be used on videos. - clip_duration: Clip duration for the clip sampler. - clip_sampler_kwargs: Extra ClipSampler keyword arguments. - video_sampler: Sampler for the internal video container. - This defines the order videos are decoded and, if necessary, the distributed split. - decode_audio: Whether to decode the audio with the video clip. - decoder: Defines what type of decoder used to decode a video. - train_transform: Video clip dictionary transform to use for training set. 
- val_transform: Video clip dictionary transform to use for validation set. - test_transform: Video clip dictionary transform to use for test set. - predict_transform: Video clip dictionary transform to use for predict set. - batch_size: Batch size for data loading. - num_workers: The number of workers to use for parallelized loading. - Defaults to ``None`` which equals the number of available CPU threads. - preprocess: VideoClassifierPreprocess to handle the data processing. - - Returns: - VideoClassificationData: the constructed data module - - Examples: - >>> videos = VideoClassificationData.from_paths("train/") # doctest: +SKIP + ): + self.clip_sampler = clip_sampler + self.clip_duration = clip_duration + self.clip_sampler_kwargs = clip_sampler_kwargs + self.video_sampler = video_sampler + self.decode_audio = decode_audio + self.decoder = decoder - """ if not _PYTORCHVIDEO_AVAILABLE: raise ModuleNotFoundError("Please, run `pip install pytorchvideo`.") @@ -265,19 +140,67 @@ def from_paths( clip_sampler = make_clip_sampler(clip_sampler, clip_duration, **clip_sampler_kwargs) - preprocess: Preprocess = preprocess or cls.preprocess_cls( - clip_sampler, video_sampler, decode_audio, decoder, train_transform, val_transform, test_transform, - predict_transform + super().__init__( + train_transform=train_transform, + val_transform=val_transform, + test_transform=test_transform, + predict_transform=predict_transform, + data_sources={ + DefaultDataSources.PATHS: VideoClassificationPathsDataSource( + clip_sampler, + video_sampler=video_sampler, + decode_audio=decode_audio, + decoder=decoder, + ) + } ) - return cls.from_load_data_inputs( - train_load_data_input=train_data_path, - val_load_data_input=val_data_path, - test_load_data_input=test_data_path, - predict_load_data_input=predict_data_path, - batch_size=batch_size, - num_workers=num_workers, - preprocess=preprocess, - use_iterable_auto_dataset=True, - **kwargs, - ) + def get_state_dict(self) -> Dict[str, Any]: + return { + 'train_transform': self._train_transform, + 'val_transform': self._val_transform, + 'test_transform': self._test_transform, + 'predict_transform': self._predict_transform, + 'clip_sampler': self.clip_sampler, + 'clip_duration': self.clip_duration, + 'clip_sampler_kwargs': self.clip_sampler_kwargs, + 'video_sampler': self.video_sampler, + 'decode_audio': self.decode_audio, + 'decoder': self.decoder, + } + + @classmethod + def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool) -> 'VideoClassificationPreprocess': + return cls(**state_dict) + + @staticmethod + def default_predict_transform() -> Dict[str, 'Compose']: + return { + "post_tensor_transform": Compose([ + ApplyTransformToKey( + key="video", + transform=Compose([ + UniformTemporalSubsample(8), + RandomShortSideScale(min_size=256, max_size=320), + RandomCrop(244), + RandomHorizontalFlip(p=0.5), + ]), + ), + ]), + "per_batch_transform_on_device": Compose([ + ApplyTransformToKey( + key="video", + transform=K.VideoSequential( + K.Normalize(torch.tensor([0.45, 0.45, 0.45]), torch.tensor([0.225, 0.225, 0.225])), + data_format="BCTHW", + same_on_frame=False + ) + ), + ]), + } + + +class VideoClassificationData(DataModule): + """Data module for Video classification tasks.""" + + preprocess_cls = VideoClassificationPreprocess diff --git a/flash/vision/classification/data.py b/flash/vision/classification/data.py index dbfcbe97a0..65766e5e23 100644 --- a/flash/vision/classification/data.py +++ b/flash/vision/classification/data.py @@ -27,7 +27,7 @@ from 
flash.data.process import Preprocess from flash.utils.imports import _MATPLOTLIB_AVAILABLE from flash.vision.classification.transforms import default_train_transforms, default_val_transforms -from flash.vision.data import ImageFilesDataSource, ImageFoldersDataSource, ImageNumpyDataSource, ImageTensorDataSource +from flash.vision.data import ImageNumpyDataSource, ImagePathsDataSource, ImageTensorDataSource if _MATPLOTLIB_AVAILABLE: import matplotlib.pyplot as plt @@ -53,8 +53,7 @@ def __init__( test_transform=test_transform, predict_transform=predict_transform, data_sources={ - DefaultDataSources.FOLDERS: ImageFoldersDataSource(), - DefaultDataSources.FILES: ImageFilesDataSource(), + DefaultDataSources.PATHS: ImagePathsDataSource(), DefaultDataSources.NUMPY: ImageNumpyDataSource(), DefaultDataSources.TENSOR: ImageTensorDataSource(), } diff --git a/flash/vision/data.py b/flash/vision/data.py index 9d32cd603a..056e856468 100644 --- a/flash/vision/data.py +++ b/flash/vision/data.py @@ -11,54 +11,34 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Mapping, Optional +from typing import Any, Dict, Optional import torch from torchvision.datasets.folder import default_loader, IMG_EXTENSIONS from torchvision.transforms.functional import to_pil_image -from flash.data.data_source import ( - DefaultDataKeys, - FilesDataSource, - FoldersDataSource, - NumpyDataSource, - TensorDataSource, -) +from flash.data.data_source import DefaultDataKeys, NumpyDataSource, PathsDataSource, TensorDataSource -class ImageFoldersDataSource(FoldersDataSource): +class ImagePathsDataSource(PathsDataSource): def __init__(self): super().__init__(extensions=IMG_EXTENSIONS) - def load_sample(self, sample: Mapping[str, Any], dataset: Optional[Any] = None) -> Any: - result = {} # TODO: this is required to avoid a memory leak, can we automate this? - result.update(sample) - result[DefaultDataKeys.INPUT] = default_loader(sample[DefaultDataKeys.INPUT]) - return result - - -class ImageFilesDataSource(FilesDataSource): - - def __init__(self): - super().__init__(extensions=IMG_EXTENSIONS) - - def load_sample(self, sample: Mapping[str, Any], dataset: Optional[Any] = None) -> Any: - result = {} # TODO: this is required to avoid a memory leak, can we automate this? 
- result.update(sample) - result[DefaultDataKeys.INPUT] = default_loader(sample[DefaultDataKeys.INPUT]) - return result + def load_sample(self, sample: Dict[str, Any], dataset: Optional[Any] = None) -> Dict[str, Any]: + sample[DefaultDataKeys.INPUT] = default_loader(sample[DefaultDataKeys.INPUT]) + return sample class ImageTensorDataSource(TensorDataSource): - def load_sample(self, sample: Dict[str, Any], dataset: Optional[Any] = None) -> Any: + def load_sample(self, sample: Dict[str, Any], dataset: Optional[Any] = None) -> Dict[str, Any]: sample[DefaultDataKeys.INPUT] = to_pil_image(sample[DefaultDataKeys.INPUT]) return sample class ImageNumpyDataSource(NumpyDataSource): - def load_sample(self, sample: Dict[str, Any], dataset: Optional[Any] = None) -> Any: + def load_sample(self, sample: Dict[str, Any], dataset: Optional[Any] = None) -> Dict[str, Any]: sample[DefaultDataKeys.INPUT] = to_pil_image(torch.from_numpy(sample[DefaultDataKeys.INPUT])) return sample diff --git a/flash_examples/finetuning/video_classification.py b/flash_examples/finetuning/video_classification.py index 0e30141a61..4efa815dee 100644 --- a/flash_examples/finetuning/video_classification.py +++ b/flash_examples/finetuning/video_classification.py @@ -35,7 +35,7 @@ if __name__ == '__main__': - _PATH_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + _PATH_ROOT = os.path.dirname(os.path.abspath(__file__)) # 1. Download a video clip dataset. Find more dataset at https://pytorchvideo.readthedocs.io/en/latest/data.html download_data("https://pl-flash-data.s3.amazonaws.com/kinetics.zip") @@ -72,19 +72,19 @@ def make_transform( } # 3. Load the data from directories. - datamodule = VideoClassificationData.from_paths( - train_data_path=os.path.join(_PATH_ROOT, "data/kinetics/train"), - val_data_path=os.path.join(_PATH_ROOT, "data/kinetics/val"), - predict_data_path=os.path.join(_PATH_ROOT, "data/kinetics/predict"), - clip_sampler="uniform", - clip_duration=2, - video_sampler=RandomSampler, - decode_audio=False, + datamodule = VideoClassificationData.from_folders( + train_folder=os.path.join(_PATH_ROOT, "data/kinetics/train"), + val_folder=os.path.join(_PATH_ROOT, "data/kinetics/val"), + predict_folder=os.path.join(_PATH_ROOT, "data/kinetics/predict"), train_transform=make_transform(train_post_tensor_transform), val_transform=make_transform(val_post_tensor_transform), predict_transform=make_transform(val_post_tensor_transform), num_workers=8, batch_size=8, + clip_sampler="uniform", + clip_duration=2, + video_sampler=RandomSampler, + decode_audio=False, ) # 4. List the available models @@ -97,12 +97,11 @@ def make_transform( model.serializer = Labels() # 6. Finetune the model - trainer = flash.Trainer(max_epochs=3, gpus=1) + trainer = flash.Trainer(max_epochs=3) trainer.finetune(model, datamodule=datamodule, strategy=NoFreeze()) trainer.save_checkpoint("video_classification.pt") # 7. 
Make a prediction - val_folder = os.path.join(_PATH_ROOT, os.path.join(_PATH_ROOT, "data/kinetics/predict")) - predictions = model.predict([os.path.join(val_folder, f) for f in os.listdir(val_folder)]) + predictions = model.predict(os.path.join(_PATH_ROOT, "data/kinetics/predict")) print(predictions) diff --git a/flash_examples/predict/video_classification.py b/flash_examples/predict/video_classification.py index 0fd790b492..2bf8bff520 100644 --- a/flash_examples/predict/video_classification.py +++ b/flash_examples/predict/video_classification.py @@ -11,25 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import os import sys -from typing import Callable, List -import torch -from torch.utils.data.sampler import RandomSampler - -import flash -from flash.core.classification import Labels -from flash.core.finetuning import NoFreeze from flash.data.utils import download_data from flash.utils.imports import _KORNIA_AVAILABLE, _PYTORCHVIDEO_AVAILABLE -from flash.video import VideoClassificationData, VideoClassifier +from flash.video import VideoClassifier -if _PYTORCHVIDEO_AVAILABLE and _KORNIA_AVAILABLE: - import kornia.augmentation as K - from pytorchvideo.transforms import ApplyTransformToKey, RandomShortSideScale, UniformTemporalSubsample - from torchvision.transforms import CenterCrop, Compose, RandomCrop, RandomHorizontalFlip -else: +if not (_PYTORCHVIDEO_AVAILABLE and _KORNIA_AVAILABLE): print("Please, run `pip install torchvideo kornia`") sys.exit(0) @@ -41,6 +29,5 @@ ) # 2. Make a prediction -predict_folder = "data/kinetics/predict/" -predictions = model.predict([os.path.join(predict_folder, f) for f in os.listdir(predict_folder)]) +predictions = model.predict("data/kinetics/predict/") print(predictions) From 83024bb724b881a926b9dad0b1ab2e94094ee7b7 Mon Sep 17 00:00:00 2001 From: Edgar Riba Date: Thu, 6 May 2021 14:15:21 +0200 Subject: [PATCH 30/78] add smoke tests for autodataset --- tests/data/test_auto_dataset.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/tests/data/test_auto_dataset.py b/tests/data/test_auto_dataset.py index 2d50e671e4..e9c21992ca 100644 --- a/tests/data/test_auto_dataset.py +++ b/tests/data/test_auto_dataset.py @@ -19,6 +19,7 @@ from flash.data.auto_dataset import AutoDataset from flash.data.callback import FlashCallback from flash.data.data_pipeline import DataPipeline +from flash.data.data_source import DataSource from flash.data.process import Preprocess @@ -92,6 +93,18 @@ def train_load_data_with_dataset(self, data, dataset): return data +# TODO: we should test the different data types +@pytest.mark.parametrize("running_stage", [RunningStage.TRAINING, RunningStage.TESTING, RunningStage.VALIDATING]) +def test_autodataset_smoke(running_stage): + dset = AutoDataset(data=range(10), data_source=DataSource(), running_stage=running_stage) + assert dset is not None + assert dset.running_stage == running_stage + + # test set the running stage + dset.running_stage = RunningStage.PREDICTING + assert dset.running_stage == RunningStage.PREDICTING + + @pytest.mark.parametrize( "with_dataset,with_running_stage", [ From 8309080ae3041fcb5505bdb045d5331c64671593 Mon Sep 17 00:00:00 2001 From: Edgar Riba Date: Thu, 6 May 2021 14:28:39 +0200 Subject: [PATCH 31/78] improve autodataset test --- tests/data/test_auto_dataset.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git 
a/tests/data/test_auto_dataset.py b/tests/data/test_auto_dataset.py index e9c21992ca..28f8a3ee3c 100644 --- a/tests/data/test_auto_dataset.py +++ b/tests/data/test_auto_dataset.py @@ -96,14 +96,25 @@ def train_load_data_with_dataset(self, data, dataset): # TODO: we should test the different data types @pytest.mark.parametrize("running_stage", [RunningStage.TRAINING, RunningStage.TESTING, RunningStage.VALIDATING]) def test_autodataset_smoke(running_stage): - dset = AutoDataset(data=range(10), data_source=DataSource(), running_stage=running_stage) + dt = range(10) + ds = DataSource() + dset = AutoDataset(data=dt, data_source=ds, running_stage=running_stage) assert dset is not None assert dset.running_stage == running_stage + # check on members + assert dset.data == dt + assert dset.data_source == ds + # test set the running stage dset.running_stage = RunningStage.PREDICTING assert dset.running_stage == RunningStage.PREDICTING + # check on methods + assert dset.load_sample is not None + assert dset.load_sample == ds.load_sample + pass + @pytest.mark.parametrize( "with_dataset,with_running_stage", From f1c44a1cb5cd7010f29d5ace5623ecdd5c11a9b2 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Thu, 6 May 2021 13:30:02 +0100 Subject: [PATCH 32/78] Fix some tests --- flash/core/model.py | 4 +- flash/data/auto_dataset.py | 3 +- flash/data/data_source.py | 18 ++++--- flash/data/process.py | 21 +++++++++ flash/tabular/classification/data/data.py | 3 +- flash/text/classification/data.py | 3 +- flash/text/seq2seq/core/data.py | 1 + flash/video/classification/data.py | 3 +- flash/vision/classification/data.py | 3 +- flash/vision/detection/data.py | 3 +- flash_examples/predict/summarization.py | 47 +++++++++---------- .../predict/tabular_classification.py | 2 +- flash_examples/predict/text_classification.py | 17 +++---- flash_examples/predict/translation.py | 11 ++--- tests/core/test_model.py | 4 +- 15 files changed, 84 insertions(+), 59 deletions(-) diff --git a/flash/core/model.py b/flash/core/model.py index 8d390142ea..63c47874d8 100644 --- a/flash/core/model.py +++ b/flash/core/model.py @@ -158,7 +158,7 @@ def test_step(self, batch: Any, batch_idx: int) -> None: def predict( self, x: Any, - data_source: Union[str, DataSource] = DefaultDataSources.PATHS, + data_source: Union[str] = "default", data_pipeline: Optional[DataPipeline] = None, ) -> Any: """ @@ -260,7 +260,7 @@ def serializer(self, serializer: Union[Serializer, Mapping[str, Serializer]]): def build_data_pipeline( self, - data_source: Optional[Union[str, DataSource]] = None, + data_source: Optional[str] = None, data_pipeline: Optional[DataPipeline] = None, ) -> Optional[DataPipeline]: """Build a :class:`.DataPipeline` incorporating available diff --git a/flash/data/auto_dataset.py b/flash/data/auto_dataset.py index 4ede9bb5b7..ad4131a78e 100644 --- a/flash/data/auto_dataset.py +++ b/flash/data/auto_dataset.py @@ -74,7 +74,8 @@ def running_stage(self, running_stage: RunningStage) -> None: def _call_load_sample(self, sample: Any) -> Any: if self.load_sample: - sample = dict(**sample) + if isinstance(sample, dict): + sample = dict(**sample) with self._load_sample_context: parameters = signature(self.load_sample).parameters if len(parameters) > 1 and self.DATASET_KEY in parameters: diff --git a/flash/data/data_source.py b/flash/data/data_source.py index 3db2debe68..ce42b83bf0 100644 --- a/flash/data/data_source.py +++ b/flash/data/data_source.py @@ -209,10 +209,17 @@ def find_classes(dir: str) -> Tuple[List[str], Dict[str, int]]: class_to_idx = 
{cls_name: i for i, cls_name in enumerate(classes)} return classes, class_to_idx + @staticmethod + def isdir(data: Union[str, Tuple[List[str], List[Any]]]) -> bool: + try: + return os.path.isdir(data) + except TypeError: + return False + def load_data(self, data: Union[str, Tuple[List[str], List[Any]]], dataset: Optional[Any] = None) -> Iterable[Mapping[str, Any]]: - if isinstance(data, str) and os.path.isdir(data): + if self.isdir(data): classes, class_to_idx = self.find_classes(data) if not classes: return self.predict_load_data(data) @@ -234,11 +241,10 @@ def load_data(self, def predict_load_data(self, data: Union[str, List[str]], dataset: Optional[Any] = None) -> Iterable[Mapping[str, Any]]: - if isinstance(data, str): - if os.path.isdir(data): - data = [os.path.join(data, file) for file in os.listdir(data)] - else: - data = [data] + if self.isdir(data): + data = [os.path.join(data, file) for file in os.listdir(data)] + else: + data = [data] return list( filter( lambda sample: has_file_allowed_extension(sample[DefaultDataKeys.INPUT], self.extensions), diff --git a/flash/data/process.py b/flash/data/process.py index f2a5f9306d..619e9c2b1a 100644 --- a/flash/data/process.py +++ b/flash/data/process.py @@ -311,6 +311,7 @@ def __init__( test_transform: Optional[Dict[str, Callable]] = None, predict_transform: Optional[Dict[str, Callable]] = None, data_sources: Optional[Dict[str, 'DataSource']] = None, + default_data_source: Optional[str] = None, ): super().__init__() @@ -337,6 +338,7 @@ def __init__( self.predict_transform = convert_to_modules(self._predict_transform) self._data_sources = data_sources + self._default_data_source = default_data_source self._callbacks: List[FlashCallback] = [] def _save_to_state_dict(self, destination, prefix, keep_vars): @@ -516,6 +518,8 @@ def per_batch_transform_on_device(self, batch: Any) -> Any: return self.current_transform(batch) def data_source_of_name(self, data_source_name: str) -> Optional[DATA_SOURCE_TYPE]: + if data_source_name == "default": + data_source_name = self._default_data_source data_sources = self._data_sources if data_source_name in data_sources: return data_sources[data_source_name] @@ -524,6 +528,23 @@ def data_source_of_name(self, data_source_name: str) -> Optional[DATA_SOURCE_TYP class DefaultPreprocess(Preprocess): + def __init__( + self, + train_transform: Optional[Dict[str, Callable]] = None, + val_transform: Optional[Dict[str, Callable]] = None, + test_transform: Optional[Dict[str, Callable]] = None, + predict_transform: Optional[Dict[str, Callable]] = None, + ): + from flash.data.data_source import DataSource + super().__init__( + train_transform=train_transform, + val_transform=val_transform, + test_transform=test_transform, + predict_transform=predict_transform, + data_sources={"default": DataSource()}, + default_data_source="default", + ) + def get_state_dict(self) -> Dict[str, Any]: return {} diff --git a/flash/tabular/classification/data/data.py b/flash/tabular/classification/data/data.py index f85500b422..4e3029e457 100644 --- a/flash/tabular/classification/data/data.py +++ b/flash/tabular/classification/data/data.py @@ -143,7 +143,8 @@ def __init__( "df": TabularDataFrameDataSource( cat_cols, num_cols, target_col, mean, std, codes, target_codes, classes, is_regression ), - } + }, + default_data_source=DefaultDataSources.CSV, ) def get_state_dict(self, strict: bool = False) -> Dict[str, Any]: diff --git a/flash/text/classification/data.py b/flash/text/classification/data.py index c3b1a4e6e7..44eeea257f 100644 --- 
a/flash/text/classification/data.py +++ b/flash/text/classification/data.py @@ -160,7 +160,8 @@ def __init__( data_sources={ DefaultDataSources.CSV: TextCSVDataSource(self.tokenizer, max_length=max_length), "sentences": TextSentencesDataSource(self.tokenizer, max_length=max_length), - } + }, + default_data_source="sentences", ) def get_state_dict(self) -> Dict[str, Any]: diff --git a/flash/text/seq2seq/core/data.py b/flash/text/seq2seq/core/data.py index 410f3c58d4..c7e8823824 100644 --- a/flash/text/seq2seq/core/data.py +++ b/flash/text/seq2seq/core/data.py @@ -197,6 +197,7 @@ def __init__( padding=padding, ), }, + default_data_source="sentences", ) def get_state_dict(self) -> Dict[str, Any]: diff --git a/flash/video/classification/data.py b/flash/video/classification/data.py index 02ec133308..0b035f786a 100644 --- a/flash/video/classification/data.py +++ b/flash/video/classification/data.py @@ -152,7 +152,8 @@ def __init__( decode_audio=decode_audio, decoder=decoder, ) - } + }, + default_data_source=DefaultDataSources.PATHS, ) def get_state_dict(self) -> Dict[str, Any]: diff --git a/flash/vision/classification/data.py b/flash/vision/classification/data.py index 65766e5e23..c212e5e0cc 100644 --- a/flash/vision/classification/data.py +++ b/flash/vision/classification/data.py @@ -56,7 +56,8 @@ def __init__( DefaultDataSources.PATHS: ImagePathsDataSource(), DefaultDataSources.NUMPY: ImageNumpyDataSource(), DefaultDataSources.TENSOR: ImageTensorDataSource(), - } + }, + default_data_source=DefaultDataSources.PATHS, ) def get_state_dict(self) -> Dict[str, Any]: diff --git a/flash/vision/detection/data.py b/flash/vision/detection/data.py index 6c1b0e6a97..5cb026f7b5 100644 --- a/flash/vision/detection/data.py +++ b/flash/vision/detection/data.py @@ -104,7 +104,8 @@ def __init__( predict_transform=predict_transform, data_sources={ "coco": COCODataSource(), - } + }, + default_data_source="coco", ) def collate(self, samples: Any) -> Any: diff --git a/flash_examples/predict/summarization.py b/flash_examples/predict/summarization.py index eb3358b1f4..3acaac05a9 100644 --- a/flash_examples/predict/summarization.py +++ b/flash_examples/predict/summarization.py @@ -23,31 +23,28 @@ model = SummarizationTask.load_from_checkpoint("../finetuning/summarization_model_xsum.pt") # 2a. Summarize an article! -predictions = model.predict( - [ - """ - Camilla bought a box of mangoes with a Brixton £10 note, introduced last year to try to keep the money of local - people within the community.The couple were surrounded by shoppers as they walked along Electric Avenue. - They came to Brixton to see work which has started to revitalise the borough. - It was Charles' first visit to the area since 1996, when he was accompanied by the former - South African president Nelson Mandela.Greengrocer Derek Chong, who has run a stall on Electric Avenue - for 20 years, said Camilla had been ""nice and pleasant"" when she purchased the fruit. - ""She asked me what was nice, what would I recommend, and I said we've got some nice mangoes. - She asked me were they ripe and I said yes - they're from the Dominican Republic."" - Mr Chong is one of 170 local retailers who accept the Brixton Pound. - Customers exchange traditional pound coins for Brixton Pounds and then spend them at the market - or in participating shops. - During the visit, Prince Charles spent time talking to youth worker Marcus West, who works with children - nearby on an estate off Coldharbour Lane. Mr West said: - ""He's on the level, really down-to-earth. 
They were very cheery. The prince is a lovely man."" - He added: ""I told him I was working with young kids and he said, 'Keep up all the good work.'"" - Prince Charles also visited the Railway Hotel, at the invitation of his charity The Prince's Regeneration Trust. - The trust hopes to restore and refurbish the building, - where once Jimi Hendrix and The Clash played, as a new community and business centre." - """ - ], - data_source="sentences", -) +predictions = model.predict([ + """ + Camilla bought a box of mangoes with a Brixton £10 note, introduced last year to try to keep the money of local + people within the community.The couple were surrounded by shoppers as they walked along Electric Avenue. + They came to Brixton to see work which has started to revitalise the borough. + It was Charles' first visit to the area since 1996, when he was accompanied by the former + South African president Nelson Mandela.Greengrocer Derek Chong, who has run a stall on Electric Avenue + for 20 years, said Camilla had been ""nice and pleasant"" when she purchased the fruit. + ""She asked me what was nice, what would I recommend, and I said we've got some nice mangoes. + She asked me were they ripe and I said yes - they're from the Dominican Republic."" + Mr Chong is one of 170 local retailers who accept the Brixton Pound. + Customers exchange traditional pound coins for Brixton Pounds and then spend them at the market + or in participating shops. + During the visit, Prince Charles spent time talking to youth worker Marcus West, who works with children + nearby on an estate off Coldharbour Lane. Mr West said: + ""He's on the level, really down-to-earth. They were very cheery. The prince is a lovely man."" + He added: ""I told him I was working with young kids and he said, 'Keep up all the good work.'"" + Prince Charles also visited the Railway Hotel, at the invitation of his charity The Prince's Regeneration Trust. + The trust hopes to restore and refurbish the building, + where once Jimi Hendrix and The Clash played, as a new community and business centre." + """ +]) print(predictions) # 2b. Or generate summaries from a sheet file! diff --git a/flash_examples/predict/tabular_classification.py b/flash_examples/predict/tabular_classification.py index dcee9c859d..88fb569c16 100644 --- a/flash_examples/predict/tabular_classification.py +++ b/flash_examples/predict/tabular_classification.py @@ -24,5 +24,5 @@ model.serializer = Labels(['Did not survive', 'Survived']) # 3. Generate predictions from a sheet file! Who would survive? -predictions = model.predict("data/titanic/titanic.csv", data_source="csv") +predictions = model.predict("data/titanic/titanic.csv") print(predictions) diff --git a/flash_examples/predict/text_classification.py b/flash_examples/predict/text_classification.py index 21d19854df..705fee9f92 100644 --- a/flash_examples/predict/text_classification.py +++ b/flash_examples/predict/text_classification.py @@ -26,16 +26,13 @@ model.serializer = Labels() # 2a. Classify a few sentences! How was the movie? 
-predictions = model.predict( - [ - "Turgid dialogue, feeble characterization - Harvey Keitel a judge?.", - "The worst movie in the history of cinema.", - "I come from Bulgaria where it 's almost impossible to have a tornado.", - "Very, very afraid.", - "This guy has done a great job with this movie!", - ], - data_source="sentences", -) +predictions = model.predict([ + "Turgid dialogue, feeble characterization - Harvey Keitel a judge?.", + "The worst movie in the history of cinema.", + "I come from Bulgaria where it 's almost impossible to have a tornado.", + "Very, very afraid.", + "This guy has done a great job with this movie!", +]) print(predictions) # 2b. Or generate predictions from a sheet file! diff --git a/flash_examples/predict/translation.py b/flash_examples/predict/translation.py index b7767c0718..112658ad33 100644 --- a/flash_examples/predict/translation.py +++ b/flash_examples/predict/translation.py @@ -23,11 +23,8 @@ model = TranslationTask.load_from_checkpoint("../finetuning/translation_model_en_ro.pt") # 3. Translate a few sentences! -predictions = model.predict( - [ - "BBC News went to meet one of the project's first graduates.", - "A recession has come as quickly as 11 months after the first rate hike and as long as 86 months.", - ], - data_source="sentences", -) +predictions = model.predict([ + "BBC News went to meet one of the project's first graduates.", + "A recession has come as quickly as 11 months after the first rate hike and as long as 86 months.", +]) print(predictions) diff --git a/tests/core/test_model.py b/tests/core/test_model.py index 6a60071f74..df7dc89e33 100644 --- a/tests/core/test_model.py +++ b/tests/core/test_model.py @@ -29,7 +29,7 @@ import flash from flash.core.classification import ClassificationTask -from flash.data.process import Postprocess +from flash.data.process import DefaultPreprocess, Postprocess from flash.tabular import TabularClassifier from flash.text import SummarizationTask, TextClassifier from flash.utils.imports import _TRANSFORMERS_AVAILABLE @@ -75,7 +75,7 @@ def test_classificationtask_train(tmpdir: str, metrics: Any): def test_classificationtask_task_predict(): model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10), nn.Softmax()) - task = ClassificationTask(model) + task = ClassificationTask(model, preprocess=DefaultPreprocess()) ds = DummyDataset() expected = list(range(10)) # single item From 47e8f3fa3f6baa42c55b57c245a4e6c2e6446ad9 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Thu, 6 May 2021 13:47:52 +0100 Subject: [PATCH 33/78] Fix a test --- tests/data/test_callback.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/tests/data/test_callback.py b/tests/data/test_callback.py index 0bc47a91cd..26b1a941a0 100644 --- a/tests/data/test_callback.py +++ b/tests/data/test_callback.py @@ -11,19 +11,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Sequence, Tuple from unittest import mock -from unittest.mock import ANY, call, MagicMock, Mock +from unittest.mock import ANY, call, MagicMock import torch -from pytorch_lightning.callbacks.base import Callback from pytorch_lightning.trainer.states import RunningStage -from torch import Tensor from flash.core.model import Task from flash.core.trainer import Trainer from flash.data.data_module import DataModule -from flash.data.process import Preprocess +from flash.data.process import DefaultPreprocess @mock.patch("torch.save") # need to mock torch.save or we get pickle error @@ -33,7 +30,9 @@ def test_flash_callback(_, tmpdir): callback_mock = MagicMock() inputs = [[torch.rand(1), torch.rand(1)]] - dm = DataModule.from_load_data_inputs(inputs, inputs, inputs, None, num_workers=0) + dm = DataModule.from_data_source( + "default", inputs, inputs, inputs, None, preprocess=DefaultPreprocess(), batch_size=1, num_workers=0 + ) dm.preprocess.callbacks += [callback_mock] _ = next(iter(dm.train_dataloader())) @@ -59,7 +58,9 @@ def __init__(self): limit_train_batches=1, progress_bar_refresh_rate=0, ) - dm = DataModule.from_load_data_inputs(inputs, inputs, inputs, None, num_workers=0) + dm = DataModule.from_data_source( + "default", inputs, inputs, inputs, None, preprocess=DefaultPreprocess(), batch_size=1, num_workers=0 + ) dm.preprocess.callbacks += [callback_mock] trainer.fit(CustomModel(), datamodule=dm) From f3a238e08bb86dead23999feb66f9ed132ee0165 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Thu, 6 May 2021 14:01:04 +0100 Subject: [PATCH 34/78] Fixes --- flash/data/data_source.py | 4 +++- tests/data/test_callbacks.py | 44 ++++++++++++++++++------------------ 2 files changed, 25 insertions(+), 23 deletions(-) diff --git a/flash/data/data_source.py b/flash/data/data_source.py index ce42b83bf0..b0fa2cd938 100644 --- a/flash/data/data_source.py +++ b/flash/data/data_source.py @@ -243,8 +243,10 @@ def predict_load_data(self, dataset: Optional[Any] = None) -> Iterable[Mapping[str, Any]]: if self.isdir(data): data = [os.path.join(data, file) for file in os.listdir(data)] - else: + + if not isinstance(data, list): data = [data] + return list( filter( lambda sample: has_file_allowed_extension(sample[DefaultDataKeys.INPUT], self.extensions), diff --git a/tests/data/test_callbacks.py b/tests/data/test_callbacks.py index 46a9347cfa..549b406b1b 100644 --- a/tests/data/test_callbacks.py +++ b/tests/data/test_callbacks.py @@ -25,7 +25,8 @@ from flash.data.base_viz import BaseVisualization from flash.data.callback import BaseDataFetcher from flash.data.data_module import DataModule -from flash.data.process import DefaultPreprocess, Preprocess +from flash.data.data_source import DefaultDataKeys +from flash.data.process import DefaultPreprocess from flash.data.utils import _PREPROCESS_FUNCS, _STAGES_PREFIX from flash.vision import ImageClassificationData @@ -60,13 +61,14 @@ def from_inputs(cls, train_data: Any, val_data: Any, test_data: Any, predict_dat preprocess = DefaultPreprocess() - return cls.from_load_data_inputs( - train_load_data_input=train_data, - val_load_data_input=val_data, - test_load_data_input=test_data, - predict_load_data_input=predict_data, + return cls.from_data_source( + "default", + train_data=train_data, + val_data=val_data, + test_data=test_data, + predict_data=predict_data, preprocess=preprocess, - batch_size=5 + batch_size=5, ) dm = CustomDataModule.from_inputs(range(5), range(5), range(5), range(5)) @@ -133,14 +135,14 @@ def configure_data_fetcher(*args, 
**kwargs) -> CustomBaseVisualization: B: int = 2 # batch_size - dm = CustomImageClassificationData.from_filepaths( - train_filepaths=train_images, - train_labels=[0, 1], - val_filepaths=train_images, - val_labels=[2, 3], - test_filepaths=train_images, - test_labels=[4, 5], - predict_filepaths=train_images, + dm = CustomImageClassificationData.from_files( + train_files=train_images, + train_targets=[0, 1], + val_files=train_images, + val_targets=[2, 3], + test_files=train_images, + test_targets=[4, 5], + predict_files=train_images, batch_size=B, num_workers=0, ) @@ -157,9 +159,7 @@ def configure_data_fetcher(*args, **kwargs) -> CustomBaseVisualization: is_predict = stage == "predict" def _extract_data(data): - if not is_predict: - return data[0][0] - return data[0] + return data[0][DefaultDataKeys.INPUT] def _get_result(function_name: str): return dm.data_fetcher.batches[stage][function_name] @@ -170,7 +170,7 @@ def _get_result(function_name: str): if not is_predict: res = _get_result("load_sample") - assert isinstance(res[0][1], torch.Tensor) + assert isinstance(res[0][DefaultDataKeys.TARGET], int) res = _get_result("to_tensor_transform") assert len(res) == B @@ -178,21 +178,21 @@ def _get_result(function_name: str): if not is_predict: res = _get_result("to_tensor_transform") - assert isinstance(res[0][1], torch.Tensor) + assert isinstance(res[0][DefaultDataKeys.TARGET], torch.Tensor) res = _get_result("collate") assert _extract_data(res).shape == (B, 3, 196, 196) if not is_predict: res = _get_result("collate") - assert res[0][1].shape == torch.Size([2]) + assert res[0][DefaultDataKeys.TARGET].shape == torch.Size([2]) res = _get_result("per_batch_transform") assert _extract_data(res).shape == (B, 3, 196, 196) if not is_predict: res = _get_result("per_batch_transform") - assert res[0][1].shape == (B, ) + assert res[0][DefaultDataKeys.TARGET].shape == (B, ) assert dm.data_fetcher.show_load_sample_called assert dm.data_fetcher.show_pre_tensor_transform_called From eb5cfdd7e404da23b570e81f3b501c9ca8a9623e Mon Sep 17 00:00:00 2001 From: Edgar Riba Date: Thu, 6 May 2021 15:25:20 +0200 Subject: [PATCH 35/78] add tests for base and iterable --- tests/data/test_auto_dataset.py | 67 +++++++++++++++++++++++++++++++-- 1 file changed, 63 insertions(+), 4 deletions(-) diff --git a/tests/data/test_auto_dataset.py b/tests/data/test_auto_dataset.py index 28f8a3ee3c..ca57b41329 100644 --- a/tests/data/test_auto_dataset.py +++ b/tests/data/test_auto_dataset.py @@ -16,7 +16,7 @@ import pytest from pytorch_lightning.trainer.states import RunningStage -from flash.data.auto_dataset import AutoDataset +from flash.data.auto_dataset import AutoDataset, BaseAutoDataset, IterableAutoDataset from flash.data.callback import FlashCallback from flash.data.data_pipeline import DataPipeline from flash.data.data_source import DataSource @@ -95,10 +95,10 @@ def train_load_data_with_dataset(self, data, dataset): # TODO: we should test the different data types @pytest.mark.parametrize("running_stage", [RunningStage.TRAINING, RunningStage.TESTING, RunningStage.VALIDATING]) -def test_autodataset_smoke(running_stage): +def test_base_autodataset_smoke(running_stage): dt = range(10) ds = DataSource() - dset = AutoDataset(data=dt, data_source=ds, running_stage=running_stage) + dset = BaseAutoDataset(data=dt, data_source=ds, running_stage=running_stage) assert dset is not None assert dset.running_stage == running_stage @@ -113,9 +113,65 @@ def test_autodataset_smoke(running_stage): # check on methods assert dset.load_sample is not 
None assert dset.load_sample == ds.load_sample - pass +def test_autodataset_smoke(): + num_samples = 20 + dt = range(num_samples) + ds = DataSource() + + dset = AutoDataset(data=dt, data_source=ds, running_stage=RunningStage.TRAINING) + assert dset is not None + assert dset.running_stage == RunningStage.TRAINING + + # check on members + assert dset.data == dt + assert dset.data_source == ds + + # test set the running stage + dset.running_stage = RunningStage.PREDICTING + assert dset.running_stage == RunningStage.PREDICTING + + # check on methods + assert dset.load_sample is not None + assert dset.load_sample == ds.load_sample + + # check getters + assert len(dset) == num_samples + assert dset[0] == 0 + assert dset[9] == 9 + assert dset[11] == 11 + + +def test_iterable_autodataset_smoke(): + num_samples = 20 + dt = range(num_samples) + ds = DataSource() + + dset = IterableAutoDataset(data=dt, data_source=ds, running_stage=RunningStage.TRAINING) + assert dset is not None + assert dset.running_stage == RunningStage.TRAINING + + # check on members + assert dset.data == dt + assert dset.data_source == ds + + # test set the running stage + dset.running_stage = RunningStage.PREDICTING + assert dset.running_stage == RunningStage.PREDICTING + + # check on methods + assert dset.load_sample is not None + assert dset.load_sample == ds.load_sample + + # check getters + itr = iter(dset) + assert next(itr) == 0 + assert next(itr) == 1 + assert next(itr) == 2 + + +# TODO: do we remove ? @pytest.mark.parametrize( "with_dataset,with_running_stage", [ @@ -161,6 +217,7 @@ def test_autodataset_with_functions( assert functions.load_sample_count == len(dset) +# TODO: do we remove ? def test_autodataset_warning(): with pytest.warns( UserWarning, match="``datapipeline`` is specified but load_sample and/or load_data are also specified" @@ -168,6 +225,7 @@ def test_autodataset_warning(): AutoDataset(range(10), load_data=lambda x: x, load_sample=lambda x: x, data_pipeline=DataPipeline()) +# TODO: do we remove ? @pytest.mark.parametrize( "with_dataset", [ @@ -197,6 +255,7 @@ def test_preprocessing_data_pipeline_with_running_stage(with_dataset): assert pipe._preprocess_pipeline.train_load_data_count == 1 +# TODO: do we remove ? we are testing DataPipeline here. @pytest.mark.parametrize( "with_dataset", [ From a997b9dbe9377a182477a8ed78d875bd806a6949 Mon Sep 17 00:00:00 2001 From: Edgar Riba Date: Thu, 6 May 2021 15:57:57 +0200 Subject: [PATCH 36/78] add todo with detected error in callbacks test --- tests/data/test_callbacks.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/data/test_callbacks.py b/tests/data/test_callbacks.py index 549b406b1b..057920a8c2 100644 --- a/tests/data/test_callbacks.py +++ b/tests/data/test_callbacks.py @@ -77,6 +77,8 @@ def from_inputs(cls, train_data: Any, val_data: Any, test_data: Any, predict_dat with data_fetcher.enable(): _ = next(iter(dm.val_dataloader())) + # TODO: the method below fails because the data fetcher internally doesn't seem to cache + # properly the batches at each stage. 
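+        # ``check`` validates the batches recorded for each stage, and ``reset`` clears them
+        # again, as the assertion below confirms.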
data_fetcher.check() data_fetcher.reset() assert data_fetcher.batches == {'train': {}, 'test': {}, 'val': {}, 'predict': {}} From b18f0fd274fe0d6993b5401c503c486e6b34c099 Mon Sep 17 00:00:00 2001 From: Edgar Riba Date: Thu, 6 May 2021 16:03:20 +0200 Subject: [PATCH 37/78] fix test_data_pipeline_init_and_assignement --- tests/data/test_data_pipeline.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/data/test_data_pipeline.py b/tests/data/test_data_pipeline.py index 6b7ae78def..9caee72578 100644 --- a/tests/data/test_data_pipeline.py +++ b/tests/data/test_data_pipeline.py @@ -64,8 +64,8 @@ class SubPostprocess(Postprocess): pass data_pipeline = DataPipeline( - SubPreprocess() if use_preprocess else None, - SubPostprocess() if use_postprocess else None, + preprocess=SubPreprocess() if use_preprocess else None, + postprocess=SubPostprocess() if use_postprocess else None, ) assert isinstance(data_pipeline._preprocess_pipeline, SubPreprocess if use_preprocess else Preprocess) assert isinstance(data_pipeline._postprocess_pipeline, SubPostprocess if use_postprocess else Postprocess) From bda0a124d939ce88dccbc2d84cfa10f4933c8c67 Mon Sep 17 00:00:00 2001 From: Edgar Riba Date: Thu, 6 May 2021 16:33:57 +0200 Subject: [PATCH 38/78] fix test_data_pipeline_is_overriden_and_resolve_function_hierarchy --- flash/data/utils.py | 7 +++++-- tests/data/test_data_pipeline.py | 31 ++++--------------------------- 2 files changed, 9 insertions(+), 29 deletions(-) diff --git a/flash/data/utils.py b/flash/data/utils.py index 48bac51a93..28b6313843 100644 --- a/flash/data/utils.py +++ b/flash/data/utils.py @@ -32,9 +32,12 @@ } _STAGES_PREFIX_VALUES = {"train", "test", "val", "predict"} +_DATASOURCE_FUNCS: Set[str] = { + 'load_data', + 'load_sample', +} + _PREPROCESS_FUNCS: Set[str] = { - "load_data", - "load_sample", "pre_tensor_transform", "to_tensor_transform", "post_tensor_transform", diff --git a/tests/data/test_data_pipeline.py b/tests/data/test_data_pipeline.py index 9caee72578..d9402ba530 100644 --- a/tests/data/test_data_pipeline.py +++ b/tests/data/test_data_pipeline.py @@ -72,6 +72,8 @@ class SubPostprocess(Postprocess): model = CustomModel(postprocess=Postprocess()) model.data_pipeline = data_pipeline + # TODO: the line below should make the same effect but it's not + # data_pipeline._attach_to_model(model) if use_preprocess: assert isinstance(model._preprocess, SubPreprocess) @@ -88,21 +90,6 @@ def test_data_pipeline_is_overriden_and_resolve_function_hierarchy(tmpdir): class CustomPreprocess(DefaultPreprocess): - def load_data(self, *_, **__): - pass - - def test_load_data(self, *_, **__): - pass - - def predict_load_data(self, *_, **__): - pass - - def predict_load_sample(self, *_, **__): - pass - - def val_load_sample(self, *_, **__): - pass - def val_pre_tensor_transform(self, *_, **__): pass @@ -125,7 +112,8 @@ def test_per_batch_transform_on_device(self, *_, **__): pass preprocess = CustomPreprocess() - data_pipeline = DataPipeline(preprocess) + data_pipeline = DataPipeline(preprocess=preprocess) + train_func_names: Dict[str, str] = { k: data_pipeline._resolve_function_hierarchy( k, data_pipeline._preprocess_pipeline, RunningStage.TRAINING, Preprocess @@ -150,17 +138,6 @@ def test_per_batch_transform_on_device(self, *_, **__): ) for k in data_pipeline.PREPROCESS_FUNCS } - # load_data - assert train_func_names["load_data"] == "load_data" - assert val_func_names["load_data"] == "load_data" - assert test_func_names["load_data"] == "test_load_data" - assert 
predict_func_names["load_data"] == "predict_load_data" - - # load_sample - assert train_func_names["load_sample"] == "load_sample" - assert val_func_names["load_sample"] == "val_load_sample" - assert test_func_names["load_sample"] == "load_sample" - assert predict_func_names["load_sample"] == "predict_load_sample" # pre_tensor_transform assert train_func_names["pre_tensor_transform"] == "pre_tensor_transform" From e4a4f8a421194acd7e7c77d461829d2ecd2cc044 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Thu, 6 May 2021 15:39:39 +0100 Subject: [PATCH 39/78] Fix some tests --- flash/data/process.py | 6 +- flash/data/utils.py | 3 +- tests/data/test_data_pipeline.py | 212 ++++++++++++++----------------- 3 files changed, 102 insertions(+), 119 deletions(-) diff --git a/flash/data/process.py b/flash/data/process.py index 619e9c2b1a..01518b066f 100644 --- a/flash/data/process.py +++ b/flash/data/process.py @@ -534,6 +534,8 @@ def __init__( val_transform: Optional[Dict[str, Callable]] = None, test_transform: Optional[Dict[str, Callable]] = None, predict_transform: Optional[Dict[str, Callable]] = None, + data_sources: Optional[Dict[str, 'DataSource']] = None, + default_data_source: Optional[str] = None, ): from flash.data.data_source import DataSource super().__init__( @@ -541,8 +543,8 @@ def __init__( val_transform=val_transform, test_transform=test_transform, predict_transform=predict_transform, - data_sources={"default": DataSource()}, - default_data_source="default", + data_sources=data_sources or {"default": DataSource()}, + default_data_source=default_data_source or "default", ) def get_state_dict(self) -> Dict[str, Any]: diff --git a/flash/data/utils.py b/flash/data/utils.py index 48bac51a93..641306389d 100644 --- a/flash/data/utils.py +++ b/flash/data/utils.py @@ -33,8 +33,7 @@ _STAGES_PREFIX_VALUES = {"train", "test", "val", "predict"} _PREPROCESS_FUNCS: Set[str] = { - "load_data", - "load_sample", + # "load_sample", # TODO: This should still be a callback hook "pre_tensor_transform", "to_tensor_transform", "post_tensor_transform", diff --git a/tests/data/test_data_pipeline.py b/tests/data/test_data_pipeline.py index 6b7ae78def..0ae5d400bf 100644 --- a/tests/data/test_data_pipeline.py +++ b/tests/data/test_data_pipeline.py @@ -32,6 +32,7 @@ from flash.data.batch import _PostProcessor, _PreProcessor from flash.data.data_module import DataModule from flash.data.data_pipeline import _StageOrchestrator, DataPipeline +from flash.data.data_source import DataSource from flash.data.process import DefaultPreprocess, Postprocess, Preprocess @@ -64,10 +65,11 @@ class SubPostprocess(Postprocess): pass data_pipeline = DataPipeline( + None, SubPreprocess() if use_preprocess else None, SubPostprocess() if use_postprocess else None, ) - assert isinstance(data_pipeline._preprocess_pipeline, SubPreprocess if use_preprocess else Preprocess) + assert isinstance(data_pipeline._preprocess_pipeline, SubPreprocess if use_preprocess else DefaultPreprocess) assert isinstance(data_pipeline._postprocess_pipeline, SubPostprocess if use_postprocess else Postprocess) model = CustomModel(postprocess=Postprocess()) @@ -88,21 +90,6 @@ def test_data_pipeline_is_overriden_and_resolve_function_hierarchy(tmpdir): class CustomPreprocess(DefaultPreprocess): - def load_data(self, *_, **__): - pass - - def test_load_data(self, *_, **__): - pass - - def predict_load_data(self, *_, **__): - pass - - def predict_load_sample(self, *_, **__): - pass - - def val_load_sample(self, *_, **__): - pass - def val_pre_tensor_transform(self, 
*_, **__): pass @@ -125,7 +112,7 @@ def test_per_batch_transform_on_device(self, *_, **__): pass preprocess = CustomPreprocess() - data_pipeline = DataPipeline(preprocess) + data_pipeline = DataPipeline(preprocess=preprocess) train_func_names: Dict[str, str] = { k: data_pipeline._resolve_function_hierarchy( k, data_pipeline._preprocess_pipeline, RunningStage.TRAINING, Preprocess @@ -150,17 +137,6 @@ def test_per_batch_transform_on_device(self, *_, **__): ) for k in data_pipeline.PREPROCESS_FUNCS } - # load_data - assert train_func_names["load_data"] == "load_data" - assert val_func_names["load_data"] == "load_data" - assert test_func_names["load_data"] == "test_load_data" - assert predict_func_names["load_data"] == "predict_load_data" - - # load_sample - assert train_func_names["load_sample"] == "load_sample" - assert val_func_names["load_sample"] == "val_load_sample" - assert test_func_names["load_sample"] == "load_sample" - assert predict_func_names["load_sample"] == "predict_load_sample" # pre_tensor_transform assert train_func_names["pre_tensor_transform"] == "pre_tensor_transform" @@ -271,7 +247,7 @@ def predict_per_batch_transform_on_device(self, *_, **__): def test_data_pipeline_predict_worker_preprocessor_and_device_preprocessor(): preprocess = CustomPreprocess() - data_pipeline = DataPipeline(preprocess) + data_pipeline = DataPipeline(preprocess=preprocess) data_pipeline.worker_preprocessor(RunningStage.TRAINING) with pytest.raises(MisconfigurationException, match="are mutual exclusive"): @@ -293,7 +269,7 @@ def train_dataloader(self) -> Any: return DataLoader(DummyDataset()) preprocess = CustomPreprocess() - data_pipeline = DataPipeline(preprocess) + data_pipeline = DataPipeline(preprocess=preprocess) model = CustomModel() model.data_pipeline = data_pipeline @@ -343,7 +319,7 @@ class SubPreprocess(DefaultPreprocess): pass preprocess = SubPreprocess() - data_pipeline = DataPipeline(preprocess) + data_pipeline = DataPipeline(preprocess=preprocess) class CustomModel(Task): @@ -491,7 +467,7 @@ def _attach_postprocess_to_model(self, model: 'Task', _postprocesssor: _PostProc model.predict_step = self._model_predict_step_wrapper(model.predict_step, _postprocesssor, model) return model - data_pipeline = CustomDataPipeline(preprocess) + data_pipeline = CustomDataPipeline(preprocess=preprocess) _postprocesssor = data_pipeline._create_uncollate_postprocessors(RunningStage.PREDICTING) data_pipeline._attach_postprocess_to_model(model, _postprocesssor) assert model.predict_step._original == _original_predict_step @@ -512,23 +488,15 @@ def __len__(self) -> int: return 5 -class TestPreprocessTransformations(DefaultPreprocess): +class TestPreprocessTransformationsDataSource(DataSource): def __init__(self): super().__init__() self.train_load_data_called = False - self.train_pre_tensor_transform_called = False - self.train_collate_called = False - self.train_per_batch_transform_on_device_called = False self.val_load_data_called = False self.val_load_sample_called = False - self.val_to_tensor_transform_called = False - self.val_collate_called = False - self.val_per_batch_transform_on_device_called = False self.test_load_data_called = False - self.test_to_tensor_transform_called = False - self.test_post_tensor_transform_called = False self.predict_load_data_called = False @staticmethod @@ -546,6 +514,53 @@ def train_load_data(self, sample) -> LamdaDummyDataset: self.train_load_data_called = True return LamdaDummyDataset(self.fn_train_load_data) + def val_load_data(self, sample, dataset) -> List[int]: 
+ assert self.validating + assert self.current_fn == "load_data" + self.val_load_data_called = True + return list(range(5)) + + def val_load_sample(self, sample) -> Dict[str, Tensor]: + assert self.validating + assert self.current_fn == "load_sample" + self.val_load_sample_called = True + return {"a": sample, "b": sample + 1} + + @staticmethod + def fn_test_load_data() -> List[torch.Tensor]: + return [torch.rand(1), torch.rand(1)] + + def test_load_data(self, sample) -> LamdaDummyDataset: + assert self.testing + assert self.current_fn == "load_data" + self.test_load_data_called = True + return LamdaDummyDataset(self.fn_test_load_data) + + @staticmethod + def fn_predict_load_data() -> List[str]: + return (["a", "b"]) + + def predict_load_data(self, sample) -> LamdaDummyDataset: + assert self.predicting + assert self.current_fn == "load_data" + self.predict_load_data_called = True + return LamdaDummyDataset(self.fn_predict_load_data) + + +class TestPreprocessTransformations(DefaultPreprocess): + + def __init__(self): + super().__init__(data_sources={"default": TestPreprocessTransformationsDataSource()}) + + self.train_pre_tensor_transform_called = False + self.train_collate_called = False + self.train_per_batch_transform_on_device_called = False + self.val_to_tensor_transform_called = False + self.val_collate_called = False + self.val_per_batch_transform_on_device_called = False + self.test_to_tensor_transform_called = False + self.test_post_tensor_transform_called = False + def train_pre_tensor_transform(self, sample: Any) -> Any: assert self.training assert self.current_fn == "pre_tensor_transform" @@ -564,19 +579,6 @@ def train_per_batch_transform_on_device(self, batch: Any) -> Any: self.train_per_batch_transform_on_device_called = True assert torch.equal(batch, tensor([[0, 1, 2, 3, 5], [0, 1, 2, 3, 5]])) - def val_load_data(self, sample, dataset) -> List[int]: - assert self.validating - assert self.current_fn == "load_data" - self.val_load_data_called = True - assert isinstance(dataset, AutoDataset) - return list(range(5)) - - def val_load_sample(self, sample) -> Dict[str, Tensor]: - assert self.validating - assert self.current_fn == "load_sample" - self.val_load_sample_called = True - return {"a": sample, "b": sample + 1} - def val_to_tensor_transform(self, sample: Any) -> Tensor: assert self.validating assert self.current_fn == "to_tensor_transform" @@ -601,16 +603,6 @@ def val_per_batch_transform_on_device(self, batch: Any) -> Any: assert torch.equal(batch["b"], tensor([1, 2])) return [False] - @staticmethod - def fn_test_load_data() -> List[torch.Tensor]: - return [torch.rand(1), torch.rand(1)] - - def test_load_data(self, sample) -> LamdaDummyDataset: - assert self.testing - assert self.current_fn == "load_data" - self.test_load_data_called = True - return LamdaDummyDataset(self.fn_test_load_data) - def test_to_tensor_transform(self, sample: Any) -> Tensor: assert self.testing assert self.current_fn == "to_tensor_transform" @@ -623,16 +615,6 @@ def test_post_tensor_transform(self, sample: Tensor) -> Tensor: self.test_post_tensor_transform_called = True return sample - @staticmethod - def fn_predict_load_data() -> List[str]: - return (["a", "b"]) - - def predict_load_data(self, sample) -> LamdaDummyDataset: - assert self.predicting - assert self.current_fn == "load_data" - self.predict_load_data_called = True - return LamdaDummyDataset(self.fn_predict_load_data) - class TestPreprocessTransformations2(TestPreprocessTransformations): @@ -668,8 +650,8 @@ def predict_step(self, batch, 
batch_idx, dataloader_idx): def test_datapipeline_transformations(tmpdir): - datamodule = DataModule.from_load_data_inputs( - 1, 1, 1, 1, batch_size=2, num_workers=0, preprocess=TestPreprocessTransformations() + datamodule = DataModule.from_data_source( + "default", 1, 1, 1, 1, batch_size=2, num_workers=0, preprocess=TestPreprocessTransformations() ) assert datamodule.train_dataloader().dataset[0] == (0, 1, 2, 3) @@ -681,8 +663,8 @@ def test_datapipeline_transformations(tmpdir): with pytest.raises(MisconfigurationException, match="When ``to_tensor_transform``"): batch = next(iter(datamodule.val_dataloader())) - datamodule = DataModule.from_load_data_inputs( - 1, 1, 1, 1, batch_size=2, num_workers=0, preprocess=TestPreprocessTransformations2() + datamodule = DataModule.from_data_source( + "default", 1, 1, 1, 1, batch_size=2, num_workers=0, preprocess=TestPreprocessTransformations2() ) batch = next(iter(datamodule.val_dataloader())) assert torch.equal(batch["a"], tensor([0, 1])) @@ -702,19 +684,20 @@ def test_datapipeline_transformations(tmpdir): trainer.predict(model) preprocess = model._preprocess - assert preprocess.train_load_data_called + data_source = preprocess.data_source_of_name("default") + assert data_source.train_load_data_called assert preprocess.train_pre_tensor_transform_called assert preprocess.train_collate_called assert preprocess.train_per_batch_transform_on_device_called - assert preprocess.val_load_data_called - assert preprocess.val_load_sample_called + assert data_source.val_load_data_called + assert data_source.val_load_sample_called assert preprocess.val_to_tensor_transform_called assert preprocess.val_collate_called assert preprocess.val_per_batch_transform_on_device_called - assert preprocess.test_load_data_called + assert data_source.test_load_data_called assert preprocess.test_to_tensor_transform_called assert preprocess.test_post_tensor_transform_called - assert preprocess.predict_load_data_called + assert data_source.predict_load_data_called def test_is_overriden_recursive(tmpdir): @@ -741,12 +724,7 @@ def val_collate(self, *_): @mock.patch("torch.save") # need to mock torch.save or we get pickle error def test_dummy_example(tmpdir): - class ImageClassificationPreprocess(DefaultPreprocess): - - def __init__(self, to_tensor_transform, train_per_sample_transform_on_device): - super().__init__() - self._to_tensor = to_tensor_transform - self._train_per_sample_transform_on_device = train_per_sample_transform_on_device + class ImageDataSource(DataSource): def load_data(self, folder: str): # from folder -> return files paths @@ -757,6 +735,27 @@ def load_sample(self, path: str) -> Image.Image: img8Bit = np.uint8(np.random.uniform(0, 1, (64, 64, 3)) * 255.0) return Image.fromarray(img8Bit) + class ImageClassificationPreprocess(DefaultPreprocess): + + def __init__( + self, + train_transform=None, + val_transform=None, + test_transform=None, + predict_transform=None, + to_tensor_transform=None, + train_per_sample_transform_on_device=None, + ): + super().__init__( + train_transform=train_transform, + val_transform=val_transform, + test_transform=test_transform, + predict_transform=predict_transform, + data_sources={"default": ImageDataSource()}, + ) + self._to_tensor = to_tensor_transform + self._train_per_sample_transform_on_device = train_per_sample_transform_on_device + def to_tensor_transform(self, pil_image: Image.Image) -> Tensor: # convert pil image into a tensor return self._to_tensor(pil_image) @@ -783,32 +782,15 @@ class CustomDataModule(DataModule): 
preprocess_cls = ImageClassificationPreprocess - @property - def preprocess(self): - return self.preprocess_cls(self.to_tensor_transform, self.train_per_sample_transform_on_device) - - @classmethod - def from_folders( - cls, train_folder: Optional[str], val_folder: Optional[str], test_folder: Optional[str], - predict_folder: Optional[str], to_tensor_transform: torch.nn.Module, - train_per_sample_transform_on_device: torch.nn.Module, batch_size: int - ): - - # attach the arguments for the preprocess onto the cls - cls.to_tensor_transform = to_tensor_transform - cls.train_per_sample_transform_on_device = train_per_sample_transform_on_device - - # call ``from_load_data_inputs`` - return cls.from_load_data_inputs( - train_load_data_input=train_folder, - val_load_data_input=val_folder, - test_load_data_input=test_folder, - predict_load_data_input=predict_folder, - batch_size=batch_size - ) - - datamodule = CustomDataModule.from_folders( - "train_folder", "val_folder", "test_folder", None, T.ToTensor(), T.RandomHorizontalFlip(), batch_size=2 + datamodule = CustomDataModule.from_data_source( + "default", + "train_folder", + "val_folder", + "test_folder", + None, + batch_size=2, + to_tensor_transform=T.ToTensor(), + train_per_sample_transform_on_device=T.RandomHorizontalFlip(), ) assert isinstance(datamodule.train_dataloader().dataset[0], Image.Image) From 464fffea7c235c7885947162c5d2c46512ca26d1 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Thu, 6 May 2021 16:53:03 +0100 Subject: [PATCH 40/78] Fix some tests --- tests/data/test_data_pipeline.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/tests/data/test_data_pipeline.py b/tests/data/test_data_pipeline.py index 0fa7ba2417..721213a816 100644 --- a/tests/data/test_data_pipeline.py +++ b/tests/data/test_data_pipeline.py @@ -849,10 +849,10 @@ def test_preprocess_transforms(tmpdir): assert preprocess._test_collate_in_worker_from_transform is None assert preprocess._predict_collate_in_worker_from_transform is False - train_preprocessor = DataPipeline(preprocess).worker_preprocessor(RunningStage.TRAINING) - val_preprocessor = DataPipeline(preprocess).worker_preprocessor(RunningStage.VALIDATING) - test_preprocessor = DataPipeline(preprocess).worker_preprocessor(RunningStage.TESTING) - predict_preprocessor = DataPipeline(preprocess).worker_preprocessor(RunningStage.PREDICTING) + train_preprocessor = DataPipeline(preprocess=preprocess).worker_preprocessor(RunningStage.TRAINING) + val_preprocessor = DataPipeline(preprocess=preprocess).worker_preprocessor(RunningStage.VALIDATING) + test_preprocessor = DataPipeline(preprocess=preprocess).worker_preprocessor(RunningStage.TESTING) + predict_preprocessor = DataPipeline(preprocess=preprocess).worker_preprocessor(RunningStage.PREDICTING) assert train_preprocessor.collate_fn.func == default_collate assert val_preprocessor.collate_fn.func == default_collate @@ -877,7 +877,7 @@ def per_batch_transform(self, batch: Any) -> Any: assert preprocess._test_collate_in_worker_from_transform is None assert preprocess._predict_collate_in_worker_from_transform is False - data_pipeline = DataPipeline(preprocess) + data_pipeline = DataPipeline(preprocess=preprocess) train_preprocessor = data_pipeline.worker_preprocessor(RunningStage.TRAINING) with pytest.raises(MisconfigurationException, match="`per_batch_transform` and `per_sample_transform_on_device`"): @@ -892,14 +892,12 @@ def per_batch_transform(self, batch: Any) -> Any: def test_iterable_auto_dataset(tmpdir): - class 
CustomPreprocess(DefaultPreprocess): + class CustomDataSource(DataSource): def load_sample(self, index: int) -> Dict[str, int]: return {"index": index} - data_pipeline = DataPipeline(CustomPreprocess()) - - ds = IterableAutoDataset(range(10), running_stage=RunningStage.TRAINING, data_pipeline=data_pipeline) + ds = IterableAutoDataset(range(10), data_source=CustomDataSource(), running_stage=RunningStage.TRAINING) for index, v in enumerate(ds): assert v == {"index": index} From e7d6b6605b8b9d60f39387145ad2eb15eb99c770 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Thu, 6 May 2021 17:06:12 +0100 Subject: [PATCH 41/78] Fix some tests --- flash/core/classification.py | 3 +++ flash/data/data_pipeline.py | 2 +- tests/data/test_process.py | 6 +++--- 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/flash/core/classification.py b/flash/core/classification.py index 34d2db4a1d..b85a529b3a 100644 --- a/flash/core/classification.py +++ b/flash/core/classification.py @@ -135,6 +135,9 @@ def __init__(self, labels: Optional[List[str]] = None, multi_label: bool = False super().__init__(multi_label=multi_label, threshold=threshold) self._labels = labels + if labels is not None: + self.set_state(LabelsState(labels)) + def serialize(self, sample: Any) -> Union[int, List[int], str, List[str]]: labels = None diff --git a/flash/data/data_pipeline.py b/flash/data/data_pipeline.py index 266fdfa7f7..cbd6ba4700 100644 --- a/flash/data/data_pipeline.py +++ b/flash/data/data_pipeline.py @@ -103,7 +103,7 @@ def __init__( self._running_stage = None - def initialize(self, data_pipeline_state: Optional[DataPipelineState]) -> DataPipelineState: + def initialize(self, data_pipeline_state: Optional[DataPipelineState] = None) -> DataPipelineState: """Creates the :class:`.DataPipelineState` and gives the reference to the: :class:`.Preprocess`, :class:`.Postprocess`, and :class:`.Serializer`. 
Once this has been called, any attempt to add new state will give a warning.""" diff --git a/tests/data/test_process.py b/tests/data/test_process.py index efbbd82d2c..6f4d59f3d0 100644 --- a/tests/data/test_process.py +++ b/tests/data/test_process.py @@ -120,7 +120,7 @@ def __init__(self): serializer = Labels(["a", "b"]) model = CustomModel() trainer = Trainer(fast_dev_run=True) - data_pipeline = DataPipeline(DefaultPreprocess(), serializer=serializer) + data_pipeline = DataPipeline(preprocess=DefaultPreprocess(), serializer=serializer) data_pipeline.initialize() model.data_pipeline = data_pipeline assert isinstance(model.preprocess, DefaultPreprocess) @@ -128,5 +128,5 @@ def __init__(self): trainer.fit(model, train_dataloader=dummy_data) trainer.save_checkpoint(checkpoint_file) model = CustomModel.load_from_checkpoint(checkpoint_file) - assert isinstance(model.preprocess._data_pipeline_state, DataPipelineState) - assert model.preprocess._data_pipeline_state._state[LabelsState] == LabelsState(['a', 'b']) + assert isinstance(model._data_pipeline_state, DataPipelineState) + assert model._data_pipeline_state._state[LabelsState] == LabelsState(["a", "b"]) From cc57f866e3caee890c25cea6ddda06dd70c958fa Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Thu, 6 May 2021 17:07:38 +0100 Subject: [PATCH 42/78] Fix a test --- tests/data/test_serialization.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/data/test_serialization.py b/tests/data/test_serialization.py index 54d5ae40e6..bb166eeec8 100644 --- a/tests/data/test_serialization.py +++ b/tests/data/test_serialization.py @@ -53,7 +53,7 @@ def test_serialization_data_pipeline(tmpdir): loaded_model = CustomModel.load_from_checkpoint(checkpoint_file) assert loaded_model.data_pipeline - model.data_pipeline = DataPipeline(CustomPreprocess()) + model.data_pipeline = DataPipeline(preprocess=CustomPreprocess()) assert isinstance(model.preprocess, CustomPreprocess) trainer.fit(model, dummy_data) From 3a6308349cfafa78bcf917eb5c7c2793fafb86b1 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Thu, 6 May 2021 17:09:37 +0100 Subject: [PATCH 43/78] Fixes --- tests/tabular/classification/test_model.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/tabular/classification/test_model.py b/tests/tabular/classification/test_model.py index a1055f2711..393597118f 100644 --- a/tests/tabular/classification/test_model.py +++ b/tests/tabular/classification/test_model.py @@ -14,6 +14,7 @@ import torch from pytorch_lightning import Trainer +from flash.data.data_source import DefaultDataKeys from flash.tabular import TabularClassifier # ======== Mock functions ======== @@ -30,7 +31,7 @@ def __getitem__(self, index): target = torch.randint(0, 10, size=(1, )).item() cat_vars = torch.randint(0, 10, size=(self.num_cat, )) num_vars = torch.rand(self.num_num) - return (cat_vars, num_vars), target + return {DefaultDataKeys.INPUT: (cat_vars, num_vars), DefaultDataKeys.TARGET: target} def __len__(self) -> int: return 100 From 348995321f73493280b4fd16746e3114fafc6cb9 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Thu, 6 May 2021 17:14:53 +0100 Subject: [PATCH 44/78] Fixes --- tests/tabular/data/test_data.py | 42 ++++++++++++++++++++++----------- 1 file changed, 28 insertions(+), 14 deletions(-) diff --git a/tests/tabular/data/test_data.py b/tests/tabular/data/test_data.py index 1a181a5487..99f9432c42 100644 --- a/tests/tabular/data/test_data.py +++ b/tests/tabular/data/test_data.py @@ -18,6 +18,7 @@ import pandas as pd import 
pytest +from flash.data.data_source import DefaultDataKeys from flash.tabular import TabularData from flash.tabular.classification.data.dataset import _categorize, _normalize @@ -86,17 +87,19 @@ def test_tabular_data(tmpdir): val_df = TEST_DF_2.copy() test_df = TEST_DF_2.copy() dm = TabularData.from_df( - train_df, categorical_cols=["category"], numerical_cols=["scalar_b", "scalar_b"], target_col="label", + train_df=train_df, val_df=val_df, test_df=test_df, num_workers=0, batch_size=1, ) for dl in [dm.train_dataloader(), dm.val_dataloader(), dm.test_dataloader()]: - (cat, num), target = next(iter(dl)) + data = next(iter(dl)) + (cat, num) = data[DefaultDataKeys.INPUT] + target = data[DefaultDataKeys.TARGET] assert cat.shape == (1, 1) assert num.shape == (1, 2) assert target.shape == (1, ) @@ -111,17 +114,19 @@ def test_categorical_target(tmpdir): df["label"] = df["label"].astype(str) dm = TabularData.from_df( - train_df, categorical_cols=["category"], numerical_cols=["scalar_b", "scalar_b"], target_col="label", + train_df=train_df, val_df=val_df, test_df=test_df, num_workers=0, batch_size=1, ) for dl in [dm.train_dataloader(), dm.val_dataloader(), dm.test_dataloader()]: - (cat, num), target = next(iter(dl)) + data = next(iter(dl)) + (cat, num) = data[DefaultDataKeys.INPUT] + target = data[DefaultDataKeys.TARGET] assert cat.shape == (1, 1) assert num.shape == (1, 2) assert target.shape == (1, ) @@ -132,17 +137,19 @@ def test_from_df(tmpdir): val_df = TEST_DF_2.copy() test_df = TEST_DF_2.copy() dm = TabularData.from_df( - train_df, categorical_cols=["category"], numerical_cols=["scalar_b", "scalar_b"], target_col="label", + train_df=train_df, val_df=val_df, test_df=test_df, num_workers=0, batch_size=1 ) for dl in [dm.train_dataloader(), dm.val_dataloader(), dm.test_dataloader()]: - (cat, num), target = next(iter(dl)) + data = next(iter(dl)) + (cat, num) = data[DefaultDataKeys.INPUT] + target = data[DefaultDataKeys.TARGET] assert cat.shape == (1, 1) assert num.shape == (1, 2) assert target.shape == (1, ) @@ -156,17 +163,19 @@ def test_from_csv(tmpdir): TEST_DF_2.to_csv(test_csv) dm = TabularData.from_csv( - train_csv=train_csv, - categorical_cols=["category"], - numerical_cols=["scalar_b", "scalar_b"], - target_col="label", - val_csv=val_csv, - test_csv=test_csv, + categorical_fields=["category"], + numerical_fields=["scalar_b", "scalar_b"], + target_field="label", + train_file=train_csv, + val_file=val_csv, + test_file=test_csv, num_workers=0, batch_size=1 ) for dl in [dm.train_dataloader(), dm.val_dataloader(), dm.test_dataloader()]: - (cat, num), target = next(iter(dl)) + data = next(iter(dl)) + (cat, num) = data[DefaultDataKeys.INPUT] + target = data[DefaultDataKeys.TARGET] assert cat.shape == (1, 1) assert num.shape == (1, 2) assert target.shape == (1, ) @@ -176,5 +185,10 @@ def test_empty_inputs(): train_df = TEST_DF_1.copy() with pytest.raises(RuntimeError): TabularData.from_df( - train_df, numerical_cols=None, categorical_cols=None, target_col="label", num_workers=0, batch_size=1 + numerical_cols=None, + categorical_cols=None, + target_col="label", + train_df=train_df, + num_workers=0, + batch_size=1, ) From 1ccf7aba0236f2fac9e3146227cb76e847276e3c Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Thu, 6 May 2021 17:15:53 +0100 Subject: [PATCH 45/78] Fixes --- tests/tabular/test_data_model_integration.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/tabular/test_data_model_integration.py b/tests/tabular/test_data_model_integration.py index 
6c022eba0f..0f4cb26b62 100644 --- a/tests/tabular/test_data_model_integration.py +++ b/tests/tabular/test_data_model_integration.py @@ -32,10 +32,10 @@ def test_classification(tmpdir): val_df = TEST_DF_1.copy() test_df = TEST_DF_1.copy() data = TabularData.from_df( - train_df, categorical_cols=["category"], numerical_cols=["scalar_a", "scalar_b"], target_col="label", + train_df=train_df, val_df=val_df, test_df=test_df, num_workers=0, From 64aff9eca2afbe11c03a13125050ccd33fedac85 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Thu, 6 May 2021 17:22:52 +0100 Subject: [PATCH 46/78] Fixes --- flash/text/classification/data.py | 33 ++++++++++++++++++++++++++ tests/text/classification/test_data.py | 14 ++++------- 2 files changed, 38 insertions(+), 9 deletions(-) diff --git a/flash/text/classification/data.py b/flash/text/classification/data.py index 44eeea257f..05ab62df4f 100644 --- a/flash/text/classification/data.py +++ b/flash/text/classification/data.py @@ -159,6 +159,7 @@ def __init__( predict_transform=predict_transform, data_sources={ DefaultDataSources.CSV: TextCSVDataSource(self.tokenizer, max_length=max_length), + "json": TextJSONDataSource(self.tokenizer, max_length=max_length), "sentences": TextSentencesDataSource(self.tokenizer, max_length=max_length), }, default_data_source="sentences", @@ -202,6 +203,38 @@ class TextClassificationData(DataModule): preprocess_cls = TextClassificationPreprocess postprocess_cls = TextClassificationPostProcess + @classmethod + def from_json( + cls, + input_fields: Union[str, List[str]], + target_fields: Optional[Union[str, List[str]]] = None, + train_file: Optional[str] = None, + val_file: Optional[str] = None, + test_file: Optional[str] = None, + predict_file: Optional[str] = None, + backbone: str = "prajjwal1/bert-tiny", + max_length: int = 128, + data_fetcher: BaseDataFetcher = None, + preprocess: Optional[Preprocess] = None, + val_split: Optional[float] = None, + batch_size: int = 4, + num_workers: Optional[int] = None, + ) -> 'DataModule': + return super().from_data_source( + "json", + train_data=(train_file, input_fields, target_fields), + val_data=(val_file, input_fields, target_fields), + test_data=(test_file, input_fields, target_fields), + predict_data=(predict_file, input_fields, target_fields), + data_fetcher=data_fetcher, + preprocess=preprocess, + val_split=val_split, + batch_size=batch_size, + num_workers=num_workers, + backbone=backbone, + max_length=max_length, + ) + @classmethod def from_csv( cls, diff --git a/tests/text/classification/test_data.py b/tests/text/classification/test_data.py index 3df3360030..866b9d9328 100644 --- a/tests/text/classification/test_data.py +++ b/tests/text/classification/test_data.py @@ -48,9 +48,7 @@ def json_data(tmpdir): @pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") def test_from_csv(tmpdir): csv_path = csv_data(tmpdir) - dm = TextClassificationData.from_files( - backbone=TEST_BACKBONE, train_file=csv_path, input="sentence", target="label", batch_size=1 - ) + dm = TextClassificationData.from_csv("sentence", "label", backbone=TEST_BACKBONE, train_file=csv_path, batch_size=1) batch = next(iter(dm.train_dataloader())) assert batch["labels"].item() in [0, 1] assert "input_ids" in batch @@ -59,13 +57,13 @@ def test_from_csv(tmpdir): @pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") def test_test_valid(tmpdir): csv_path = csv_data(tmpdir) - dm = TextClassificationData.from_files( + dm = TextClassificationData.from_csv( + "sentence", + 
"label", backbone=TEST_BACKBONE, train_file=csv_path, val_file=csv_path, test_file=csv_path, - input="sentence", - target="label", batch_size=1 ) batch = next(iter(dm.val_dataloader())) @@ -80,9 +78,7 @@ def test_test_valid(tmpdir): @pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") def test_from_json(tmpdir): json_path = json_data(tmpdir) - dm = TextClassificationData.from_files( - backbone=TEST_BACKBONE, train_file=json_path, input="sentence", target="lab", filetype="json", batch_size=1 - ) + dm = TextClassificationData.from_json("sentence", "lab", backbone=TEST_BACKBONE, train_file=json_path, batch_size=1) batch = next(iter(dm.train_dataloader())) assert batch["labels"].item() in [0, 1] assert "input_ids" in batch From 33506b311a4dbcc28315bb6d2729c8741a0c4a70 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Thu, 6 May 2021 17:39:02 +0100 Subject: [PATCH 47/78] Fixes --- flash/data/data_module.py | 30 ++++++++++ flash/data/data_source.py | 1 + flash/text/classification/data.py | 67 +---------------------- flash/text/seq2seq/core/data.py | 6 ++ tests/text/summarization/test_data.py | 14 ++--- tests/text/test_data_model_integration.py | 6 +- tests/text/translation/test_data.py | 14 ++--- 7 files changed, 51 insertions(+), 87 deletions(-) diff --git a/flash/data/data_module.py b/flash/data/data_module.py index 9fce1d4ed1..0e920e7be9 100644 --- a/flash/data/data_module.py +++ b/flash/data/data_module.py @@ -535,6 +535,36 @@ def from_numpy( **preprocess_kwargs, ) + @classmethod + def from_json( + cls, + input_fields: Union[str, List[str]], + target_fields: Optional[Union[str, List[str]]] = None, + train_file: Optional[str] = None, + val_file: Optional[str] = None, + test_file: Optional[str] = None, + predict_file: Optional[str] = None, + data_fetcher: BaseDataFetcher = None, + preprocess: Optional[Preprocess] = None, + val_split: Optional[float] = None, + batch_size: int = 4, + num_workers: Optional[int] = None, + **preprocess_kwargs: Any, + ) -> 'DataModule': + return cls.from_data_source( + DefaultDataSources.JSON, + (train_file, input_fields, target_fields), + (val_file, input_fields, target_fields), + (test_file, input_fields, target_fields), + (predict_file, input_fields, target_fields), + data_fetcher=data_fetcher, + preprocess=preprocess, + val_split=val_split, + batch_size=batch_size, + num_workers=num_workers, + **preprocess_kwargs, + ) + @classmethod def from_csv( cls, diff --git a/flash/data/data_source.py b/flash/data/data_source.py index b0fa2cd938..b4846a74ea 100644 --- a/flash/data/data_source.py +++ b/flash/data/data_source.py @@ -136,6 +136,7 @@ class DefaultDataSources(LightningEnum): NUMPY = "numpy" TENSOR = "tensor" CSV = "csv" + JSON = "json" # TODO: Create a FlashEnum class??? 
def __hash__(self) -> int: diff --git a/flash/text/classification/data.py b/flash/text/classification/data.py index 05ab62df4f..145ce65454 100644 --- a/flash/text/classification/data.py +++ b/flash/text/classification/data.py @@ -159,7 +159,7 @@ def __init__( predict_transform=predict_transform, data_sources={ DefaultDataSources.CSV: TextCSVDataSource(self.tokenizer, max_length=max_length), - "json": TextJSONDataSource(self.tokenizer, max_length=max_length), + DefaultDataSources.JSON: TextJSONDataSource(self.tokenizer, max_length=max_length), "sentences": TextSentencesDataSource(self.tokenizer, max_length=max_length), }, default_data_source="sentences", @@ -202,68 +202,3 @@ class TextClassificationData(DataModule): preprocess_cls = TextClassificationPreprocess postprocess_cls = TextClassificationPostProcess - - @classmethod - def from_json( - cls, - input_fields: Union[str, List[str]], - target_fields: Optional[Union[str, List[str]]] = None, - train_file: Optional[str] = None, - val_file: Optional[str] = None, - test_file: Optional[str] = None, - predict_file: Optional[str] = None, - backbone: str = "prajjwal1/bert-tiny", - max_length: int = 128, - data_fetcher: BaseDataFetcher = None, - preprocess: Optional[Preprocess] = None, - val_split: Optional[float] = None, - batch_size: int = 4, - num_workers: Optional[int] = None, - ) -> 'DataModule': - return super().from_data_source( - "json", - train_data=(train_file, input_fields, target_fields), - val_data=(val_file, input_fields, target_fields), - test_data=(test_file, input_fields, target_fields), - predict_data=(predict_file, input_fields, target_fields), - data_fetcher=data_fetcher, - preprocess=preprocess, - val_split=val_split, - batch_size=batch_size, - num_workers=num_workers, - backbone=backbone, - max_length=max_length, - ) - - @classmethod - def from_csv( - cls, - input_fields: Union[str, List[str]], - target_fields: Optional[Union[str, List[str]]] = None, - train_file: Optional[str] = None, - val_file: Optional[str] = None, - test_file: Optional[str] = None, - predict_file: Optional[str] = None, - backbone: str = "prajjwal1/bert-tiny", - max_length: int = 128, - data_fetcher: BaseDataFetcher = None, - preprocess: Optional[Preprocess] = None, - val_split: Optional[float] = None, - batch_size: int = 4, - num_workers: Optional[int] = None, - ) -> 'DataModule': - return super().from_csv( - input_fields, - target_fields, - train_file=train_file, - val_file=val_file, - test_file=test_file, - predict_file=predict_file, - data_fetcher=data_fetcher, - preprocess=preprocess, - val_split=val_split, - batch_size=batch_size, - num_workers=num_workers, - backbone=backbone, - max_length=max_length, - ) diff --git a/flash/text/seq2seq/core/data.py b/flash/text/seq2seq/core/data.py index c7e8823824..3fbfecd6df 100644 --- a/flash/text/seq2seq/core/data.py +++ b/flash/text/seq2seq/core/data.py @@ -190,6 +190,12 @@ def __init__( max_target_length=max_target_length, padding=padding, ), + DefaultDataSources.JSON: Seq2SeqJSONDataSource( + self.tokenizer, + max_source_length=max_source_length, + max_target_length=max_target_length, + padding=padding, + ), "sentences": Seq2SeqSentencesDataSource( self.tokenizer, max_source_length=max_source_length, diff --git a/tests/text/summarization/test_data.py b/tests/text/summarization/test_data.py index 616a9d6f53..67b88bc937 100644 --- a/tests/text/summarization/test_data.py +++ b/tests/text/summarization/test_data.py @@ -48,9 +48,7 @@ def json_data(tmpdir): @pytest.mark.skipif(os.name == "nt", reason="Huggingface 
timing out on Windows") def test_from_csv(tmpdir): csv_path = csv_data(tmpdir) - dm = SummarizationData.from_files( - backbone=TEST_BACKBONE, train_file=csv_path, input="input", target="target", batch_size=1 - ) + dm = SummarizationData.from_csv("input", "target", backbone=TEST_BACKBONE, train_file=csv_path, batch_size=1) batch = next(iter(dm.train_dataloader())) assert "labels" in batch assert "input_ids" in batch @@ -59,13 +57,13 @@ def test_from_csv(tmpdir): @pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") def test_from_files(tmpdir): csv_path = csv_data(tmpdir) - dm = SummarizationData.from_files( + dm = SummarizationData.from_csv( + "input", + "target", backbone=TEST_BACKBONE, train_file=csv_path, val_file=csv_path, test_file=csv_path, - input="input", - target="target", batch_size=1 ) batch = next(iter(dm.val_dataloader())) @@ -80,9 +78,7 @@ def test_from_files(tmpdir): @pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") def test_from_json(tmpdir): json_path = json_data(tmpdir) - dm = SummarizationData.from_files( - backbone=TEST_BACKBONE, train_file=json_path, input="input", target="target", filetype="json", batch_size=1 - ) + dm = SummarizationData.from_json("input", "target", backbone=TEST_BACKBONE, train_file=json_path, batch_size=1) batch = next(iter(dm.train_dataloader())) assert "labels" in batch assert "input_ids" in batch diff --git a/tests/text/test_data_model_integration.py b/tests/text/test_data_model_integration.py index 7aeadba7de..91c10bb049 100644 --- a/tests/text/test_data_model_integration.py +++ b/tests/text/test_data_model_integration.py @@ -39,11 +39,11 @@ def test_classification(tmpdir): csv_path = csv_data(tmpdir) - data = TextClassificationData.from_files( + data = TextClassificationData.from_csv( + "sentence", + "label", backbone=TEST_BACKBONE, train_file=csv_path, - input="sentence", - target="label", num_workers=0, batch_size=2, ) diff --git a/tests/text/translation/test_data.py b/tests/text/translation/test_data.py index d9e17105ce..859bd1fe7a 100644 --- a/tests/text/translation/test_data.py +++ b/tests/text/translation/test_data.py @@ -48,9 +48,7 @@ def json_data(tmpdir): @pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") def test_from_csv(tmpdir): csv_path = csv_data(tmpdir) - dm = TranslationData.from_files( - backbone=TEST_BACKBONE, train_file=csv_path, input="input", target="target", batch_size=1 - ) + dm = TranslationData.from_csv("input", "target", backbone=TEST_BACKBONE, train_file=csv_path, batch_size=1) batch = next(iter(dm.train_dataloader())) assert "labels" in batch assert "input_ids" in batch @@ -59,13 +57,13 @@ def test_from_csv(tmpdir): @pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") def test_from_files(tmpdir): csv_path = csv_data(tmpdir) - dm = TranslationData.from_files( + dm = TranslationData.from_csv( + "input", + "target", backbone=TEST_BACKBONE, train_file=csv_path, val_file=csv_path, test_file=csv_path, - input="input", - target="target", batch_size=1 ) batch = next(iter(dm.val_dataloader())) @@ -80,9 +78,7 @@ def test_from_files(tmpdir): @pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") def test_from_json(tmpdir): json_path = json_data(tmpdir) - dm = TranslationData.from_files( - backbone=TEST_BACKBONE, train_file=json_path, input="input", target="target", filetype="json", batch_size=1 - ) + dm = TranslationData.from_json("input", "target", backbone=TEST_BACKBONE, 
train_file=json_path, batch_size=1) batch = next(iter(dm.train_dataloader())) assert "labels" in batch assert "input_ids" in batch From 1f50432c3694f70d58790fc8f991d95b0be00dbd Mon Sep 17 00:00:00 2001 From: Edgar Riba Date: Thu, 6 May 2021 18:39:29 +0200 Subject: [PATCH 48/78] deprecate csv test for image classification --- tests/vision/classification/test_data.py | 162 +++++++---------------- 1 file changed, 51 insertions(+), 111 deletions(-) diff --git a/tests/vision/classification/test_data.py b/tests/vision/classification/test_data.py index ad21f53aca..d79817d7ae 100644 --- a/tests/vision/classification/test_data.py +++ b/tests/vision/classification/test_data.py @@ -43,9 +43,14 @@ def test_from_filepaths_smoke(tmpdir): _rand_image().save(tmpdir / "a_1.png") _rand_image().save(tmpdir / "b_1.png") - img_data = ImageClassificationData.from_filepaths( - train_filepaths=[tmpdir / "a_1.png", tmpdir / "b_1.png"], - train_labels=[1, 2], + train_images = [ + str(tmpdir / "a_1.png"), + str(tmpdir / "b_1.png"), + ] + + img_data = ImageClassificationData.from_files( + train_files=train_images, + train_targets=[1, 2], batch_size=2, num_workers=0, ) @@ -54,7 +59,7 @@ def test_from_filepaths_smoke(tmpdir): assert img_data.test_dataloader() is None data = next(iter(img_data.train_dataloader())) - imgs, labels = data + imgs, labels = data['input'], data['target'] assert imgs.shape == (2, 3, 196, 196) assert labels.shape == (2, ) assert sorted(list(labels.numpy())) == [1, 2] @@ -72,20 +77,20 @@ def test_from_filepaths_list_image_paths(tmpdir): str(tmpdir / "e_1.png"), ] - img_data = ImageClassificationData.from_filepaths( - train_filepaths=train_images, - train_labels=[0, 3, 6], - val_filepaths=train_images, - val_labels=[1, 4, 7], - test_filepaths=train_images, - test_labels=[2, 5, 8], + img_data = ImageClassificationData.from_files( + train_files=train_images, + train_targets=[0, 3, 6], + val_files=train_images, + val_targets=[1, 4, 7], + test_files=train_images, + test_targets=[2, 5, 8], batch_size=2, num_workers=0, ) # check training data data = next(iter(img_data.train_dataloader())) - imgs, labels = data + imgs, labels = data['input'], data['target'] assert imgs.shape == (2, 3, 196, 196) assert labels.shape == (2, ) assert labels.numpy()[0] in [0, 3, 6] # data comes shuffled here @@ -93,14 +98,14 @@ def test_from_filepaths_list_image_paths(tmpdir): # check validation data data = next(iter(img_data.val_dataloader())) - imgs, labels = data + imgs, labels = data['input'], data['target'] assert imgs.shape == (2, 3, 196, 196) assert labels.shape == (2, ) assert list(labels.numpy()) == [1, 4] # check test data data = next(iter(img_data.test_dataloader())) - imgs, labels = data + imgs, labels = data['input'], data['target'] assert imgs.shape == (2, 3, 196, 196) assert labels.shape == (2, ) assert list(labels.numpy()) == [2, 5] @@ -109,27 +114,33 @@ def test_from_filepaths_list_image_paths(tmpdir): def test_from_filepaths_visualise(tmpdir): tmpdir = Path(tmpdir) - (tmpdir / "a").mkdir() - (tmpdir / "b").mkdir() - _rand_image().save(tmpdir / "a" / "a_1.png") - _rand_image().save(tmpdir / "b" / "b_1.png") + (tmpdir / "e").mkdir() + _rand_image().save(tmpdir / "e_1.png") - dm = ImageClassificationData.from_filepaths( - train_filepaths=[tmpdir / "a", tmpdir / "b"], - train_labels=[0, 1], - val_filepaths=[tmpdir / "b", tmpdir / "a"], - val_labels=[0, 2], - test_filepaths=[tmpdir / "b", tmpdir / "b"], - test_labels=[2, 1], + train_images = [ + str(tmpdir / "e_1.png"), + str(tmpdir / "e_1.png"), + str(tmpdir / 
"e_1.png"), + ] + + dm = ImageClassificationData.from_files( + train_files=train_images, + train_targets=[0, 3, 6], + val_files=train_images, + val_targets=[1, 4, 7], + test_files=train_images, + test_targets=[2, 5, 8], batch_size=2, + num_workers=0, ) + # disable visualisation for testing assert dm.data_fetcher.block_viz_window is True dm.set_block_viz_window(False) assert dm.data_fetcher.block_viz_window is False # call show functions - dm.show_train_batch() + # dm.show_train_batch() dm.show_train_batch("pre_tensor_transform") dm.show_train_batch(["pre_tensor_transform", "post_tensor_transform"]) @@ -209,77 +220,6 @@ def run(transform: Any = None): run(_to_tensor) -def test_categorical_csv_labels(tmpdir): - train_dir = Path(tmpdir / "some_dataset") - train_dir.mkdir() - - (train_dir / "train").mkdir() - _rand_image().save(train_dir / "train" / "train_1.png") - _rand_image().save(train_dir / "train" / "train_2.png") - - (train_dir / "valid").mkdir() - _rand_image().save(train_dir / "valid" / "val_1.png") - _rand_image().save(train_dir / "valid" / "val_2.png") - - (train_dir / "test").mkdir() - _rand_image().save(train_dir / "test" / "test_1.png") - _rand_image().save(train_dir / "test" / "test_2.png") - - train_csv = os.path.join(tmpdir, 'some_dataset', 'train.csv') - text_file = open(train_csv, 'w') - text_file.write( - 'my_id,label_a,label_b,label_c\n"train_1.png", 0, 1, 0\n"train_2.png", 0, 0, 1\n"train_2.png", 1, 0, 0\n' - ) - text_file.close() - - val_csv = os.path.join(tmpdir, 'some_dataset', 'valid.csv') - text_file = open(val_csv, 'w') - text_file.write('my_id,label_a,label_b,label_c\n"val_1.png", 0, 1, 0\n"val_2.png", 0, 0, 1\n"val_3.png", 1, 0, 0\n') - text_file.close() - - test_csv = os.path.join(tmpdir, 'some_dataset', 'test.csv') - text_file = open(test_csv, 'w') - text_file.write( - 'my_id,label_a,label_b,label_c\n"test_1.png", 0, 1, 0\n"test_2.png", 0, 0, 1\n"test_3.png", 1, 0, 0\n' - ) - text_file.close() - - def index_col_collate_fn(x): - return os.path.splitext(x)[0] - - train_labels = labels_from_categorical_csv( - train_csv, 'my_id', feature_cols=['label_a', 'label_b', 'label_c'], index_col_collate_fn=index_col_collate_fn - ) - val_labels = labels_from_categorical_csv( - val_csv, 'my_id', feature_cols=['label_a', 'label_b', 'label_c'], index_col_collate_fn=index_col_collate_fn - ) - test_labels = labels_from_categorical_csv( - test_csv, 'my_id', feature_cols=['label_a', 'label_b', 'label_c'], index_col_collate_fn=index_col_collate_fn - ) - B: int = 2 # batch_size - data = ImageClassificationData.from_filepaths( - batch_size=B, - train_filepaths=os.path.join(tmpdir, 'some_dataset', 'train'), - train_labels=train_labels.values(), - val_filepaths=os.path.join(tmpdir, 'some_dataset', 'valid'), - val_labels=val_labels.values(), - test_filepaths=os.path.join(tmpdir, 'some_dataset', 'test'), - test_labels=test_labels.values(), - ) - - for (x, y) in data.train_dataloader(): - assert len(x) == 2 - assert sorted(list(y.numpy())) == sorted(list(train_labels.values())[:B]) - - for (x, y) in data.val_dataloader(): - assert len(x) == 2 - assert sorted(list(y.numpy())) == sorted(list(val_labels.values())[:B]) - - for (x, y) in data.test_dataloader(): - assert len(x) == 2 - assert sorted(list(y.numpy())) == sorted(list(test_labels.values())[:B]) - - def test_from_folders_only_train(tmpdir): train_dir = Path(tmpdir / "train") train_dir.mkdir() @@ -295,7 +235,7 @@ def test_from_folders_only_train(tmpdir): img_data = ImageClassificationData.from_folders(train_dir, train_transform=None, 
batch_size=1) data = next(iter(img_data.train_dataloader())) - imgs, labels = data + imgs, labels = data['input'], data['target'] assert imgs.shape == (1, 3, 196, 196) assert labels.shape == (1, ) @@ -324,18 +264,18 @@ def test_from_folders_train_val(tmpdir): ) data = next(iter(img_data.train_dataloader())) - imgs, labels = data + imgs, labels = data['input'], data['target'] assert imgs.shape == (2, 3, 196, 196) assert labels.shape == (2, ) data = next(iter(img_data.val_dataloader())) - imgs, labels = data + imgs, labels = data['input'], data['target'] assert imgs.shape == (2, 3, 196, 196) assert labels.shape == (2, ) assert list(labels.numpy()) == [0, 0] data = next(iter(img_data.test_dataloader())) - imgs, labels = data + imgs, labels = data['input'], data['target'] assert imgs.shape == (2, 3, 196, 196) assert labels.shape == (2, ) assert list(labels.numpy()) == [0, 0] @@ -353,30 +293,30 @@ def test_from_filepaths_multilabel(tmpdir): valid_labels = [[1, 1, 1, 0], [1, 0, 0, 1]] test_labels = [[1, 0, 1, 0], [1, 1, 0, 1]] - dm = ImageClassificationData.from_filepaths( - train_filepaths=train_images, - train_labels=train_labels, - val_filepaths=train_images, - val_labels=valid_labels, - test_filepaths=train_images, - test_labels=test_labels, + dm = ImageClassificationData.from_files( + train_files=train_images, + train_targets=train_labels, + val_files=train_images, + val_targets=valid_labels, + test_files=train_images, + test_targets=test_labels, batch_size=2, num_workers=0, ) data = next(iter(dm.train_dataloader())) - imgs, labels = data + imgs, labels = data['input'], data['target'] assert imgs.shape == (2, 3, 196, 196) assert labels.shape == (2, 4) data = next(iter(dm.val_dataloader())) - imgs, labels = data + imgs, labels = data['input'], data['target'] assert imgs.shape == (2, 3, 196, 196) assert labels.shape == (2, 4) torch.testing.assert_allclose(labels, torch.tensor(valid_labels)) data = next(iter(dm.test_dataloader())) - imgs, labels = data + imgs, labels = data['input'], data['target'] assert imgs.shape == (2, 3, 196, 196) assert labels.shape == (2, 4) torch.testing.assert_allclose(labels, torch.tensor(test_labels)) From 6b587fefd8eda17094f14e1a61190be07e9b55d2 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Thu, 6 May 2021 17:41:45 +0100 Subject: [PATCH 49/78] Fix video --- tests/video/test_video_classifier.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/video/test_video_classifier.py b/tests/video/test_video_classifier.py index a5c3db023f..e4ed5cae88 100644 --- a/tests/video/test_video_classifier.py +++ b/tests/video/test_video_classifier.py @@ -105,15 +105,15 @@ def test_image_classifier_finetune(tmpdir): half_duration = total_duration / 2 - 1e-9 - datamodule = VideoClassificationData.from_paths( - train_data_path=mock_csv, + datamodule = VideoClassificationData.from_folders( + train_folder=mock_csv, clip_sampler="uniform", clip_duration=half_duration, video_sampler=SequentialSampler, decode_audio=False, ) - for sample in datamodule.train_dataset.dataset: + for sample in datamodule.train_dataset.data: expected_t_shape = 5 assert sample["video"].shape[1] == expected_t_shape @@ -144,8 +144,8 @@ def test_image_classifier_finetune(tmpdir): ]), } - datamodule = VideoClassificationData.from_paths( - train_data_path=mock_csv, + datamodule = VideoClassificationData.from_folders( + train_folder=mock_csv, clip_sampler="uniform", clip_duration=half_duration, video_sampler=SequentialSampler, From 4215a4783575aa5ee175bd07e64ab9276921c218 Mon Sep 17 
00:00:00 2001 From: Edgar Riba Date: Thu, 6 May 2021 19:03:25 +0200 Subject: [PATCH 50/78] fix test_from_filepaths_splits --- tests/vision/classification/test_data.py | 25 +++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/tests/vision/classification/test_data.py b/tests/vision/classification/test_data.py index d79817d7ae..cef10e4389 100644 --- a/tests/vision/classification/test_data.py +++ b/tests/vision/classification/test_data.py @@ -18,9 +18,13 @@ import kornia as K import numpy as np import torch +import torch.nn as nn +import torchvision from PIL import Image +from flash.data.data_source import DefaultDataKeys from flash.data.data_utils import labels_from_categorical_csv +from flash.data.transforms import ApplyToKeys from flash.vision import ImageClassificationData @@ -192,18 +196,17 @@ def test_from_filepaths_splits(tmpdir): assert len(train_filepaths) == len(train_labels) - def preprocess(x): - out = K.image_to_tensor(np.array(x)) - return out - _to_tensor = { - "to_tensor_transform": lambda x: preprocess(x), + "to_tensor_transform": nn.Sequential( + ApplyToKeys(DefaultDataKeys.INPUT, torchvision.transforms.ToTensor()), + ApplyToKeys(DefaultDataKeys.TARGET, torch.as_tensor) + ), } def run(transform: Any = None): - img_data = ImageClassificationData.from_filepaths( - train_filepaths=train_filepaths, - train_labels=train_labels, + dm = ImageClassificationData.from_files( + train_files=train_filepaths, + train_targets=train_labels, train_transform=transform, val_transform=transform, batch_size=B, @@ -211,12 +214,12 @@ def run(transform: Any = None): val_split=val_split, image_size=img_size, ) - data = next(iter(img_data.train_dataloader())) - imgs, labels = data + data = next(iter(dm.train_dataloader())) + imgs, labels = data['input'], data['target'] assert imgs.shape == (B, 3, H, W) assert labels.shape == (B, ) - run() + #run() run(_to_tensor) From 0256c04f405ffd4a7b914ba13bb0f273c6b26726 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Thu, 6 May 2021 18:14:10 +0100 Subject: [PATCH 51/78] Fixes --- flash/data/base_viz.py | 4 ++-- flash/data/utils.py | 5 +++++ flash/vision/classification/data.py | 11 +++++------ tests/vision/classification/test_data.py | 24 ++++++++++++++---------- 4 files changed, 26 insertions(+), 18 deletions(-) diff --git a/flash/data/base_viz.py b/flash/data/base_viz.py index c05cc93dcc..3ad1506257 100644 --- a/flash/data/base_viz.py +++ b/flash/data/base_viz.py @@ -5,7 +5,7 @@ from flash.core.utils import _is_overriden from flash.data.callback import BaseDataFetcher -from flash.data.utils import _PREPROCESS_FUNCS +from flash.data.utils import _CALLBACK_FUNCS class BaseVisualization(BaseDataFetcher): @@ -103,7 +103,7 @@ def show(self, batch: Dict[str, Any], running_stage: RunningStage, func_names_li Override this function when you want to visualize a composition. 
""" # filter out the functions to visualise - func_names_set: Set[str] = set(func_names_list) & set(_PREPROCESS_FUNCS) + func_names_set: Set[str] = set(func_names_list) & set(_CALLBACK_FUNCS) if len(func_names_set) == 0: raise MisconfigurationException(f"Invalid function names: {func_names_list}.") diff --git a/flash/data/utils.py b/flash/data/utils.py index 0210a6fafd..bf69611f2f 100644 --- a/flash/data/utils.py +++ b/flash/data/utils.py @@ -48,6 +48,11 @@ "collate", } +_CALLBACK_FUNCS: Set[str] = { + "load_sample", + *_PREPROCESS_FUNCS, +} + _POSTPROCESS_FUNCS: Set[str] = { "per_batch_transform", "uncollate", diff --git a/flash/vision/classification/data.py b/flash/vision/classification/data.py index c212e5e0cc..230c0ad417 100644 --- a/flash/vision/classification/data.py +++ b/flash/vision/classification/data.py @@ -23,7 +23,7 @@ from flash.data.base_viz import BaseVisualization # for viz from flash.data.callback import BaseDataFetcher from flash.data.data_module import DataModule -from flash.data.data_source import DefaultDataSources +from flash.data.data_source import DefaultDataKeys, DefaultDataSources from flash.data.process import Preprocess from flash.utils.imports import _MATPLOTLIB_AVAILABLE from flash.vision.classification.transforms import default_train_transforms, default_val_transforms @@ -134,10 +134,9 @@ def _show_images_and_labels(self, data: List[Any], num_samples: int, title: str) for i, ax in enumerate(axs.ravel()): # unpack images and labels if isinstance(data, list): - _img, _label = data[i] - elif isinstance(data, tuple): - imgs, labels = data - _img, _label = imgs[i], labels[i] + _img, _label = data[i][DefaultDataKeys.INPUT], data[i][DefaultDataKeys.TARGET] + elif isinstance(data, dict): + _img, _label = data[DefaultDataKeys.INPUT][i], data[DefaultDataKeys.TARGET][i] else: raise TypeError(f"Unknown data type. 
Got: {type(data)}.") # convert images to numpy @@ -168,4 +167,4 @@ def show_post_tensor_transform(self, samples: List[Any], running_stage: RunningS def show_per_batch_transform(self, batch: List[Any], running_stage): win_title: str = f"{running_stage} - show_per_batch_transform" - self._show_images_and_labels(batch[0], batch[0][0].shape[0], win_title) + self._show_images_and_labels(batch[0], batch[0][DefaultDataKeys.INPUT].shape[0], win_title) diff --git a/tests/vision/classification/test_data.py b/tests/vision/classification/test_data.py index d79817d7ae..2d475b379c 100644 --- a/tests/vision/classification/test_data.py +++ b/tests/vision/classification/test_data.py @@ -150,16 +150,20 @@ def test_from_filepaths_visualise_multilabel(tmpdir): (tmpdir / "a").mkdir() (tmpdir / "b").mkdir() - _rand_image().save(tmpdir / "a" / "a_1.png") - _rand_image().save(tmpdir / "b" / "b_1.png") - - dm = ImageClassificationData.from_filepaths( - train_filepaths=[tmpdir / "a", tmpdir / "b"], - train_labels=[[0, 1, 0], [0, 1, 1]], - val_filepaths=[tmpdir / "b", tmpdir / "a"], - val_labels=[[1, 1, 0], [0, 0, 1]], - test_filepaths=[tmpdir / "b", tmpdir / "b"], - test_labels=[[0, 0, 1], [1, 1, 0]], + + image_a = str(tmpdir / "a" / "a_1.png") + image_b = str(tmpdir / "b" / "b_1.png") + + _rand_image().save(image_a) + _rand_image().save(image_b) + + dm = ImageClassificationData.from_files( + train_files=[image_a, image_b], + train_targets=[[0, 1, 0], [0, 1, 1]], + val_files=[image_b, image_a], + val_targets=[[1, 1, 0], [0, 0, 1]], + test_files=[image_b, image_b], + test_targets=[[0, 0, 1], [1, 1, 0]], batch_size=2, ) # disable visualisation for testing From 1d5c41bc759d31afd8209b2cae6edc2fb91622bd Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Thu, 6 May 2021 18:49:35 +0100 Subject: [PATCH 52/78] Fixes --- flash/core/model.py | 7 ++++--- flash/data/data_source.py | 3 +++ flash/vision/detection/data.py | 8 +++++--- flash/vision/detection/model.py | 11 ++++++++--- tests/vision/classification/test_data.py | 1 + .../classification/test_data_model_integration.py | 15 +++++++++------ tests/vision/classification/test_model.py | 15 +++++++++++---- tests/vision/detection/test_data.py | 9 +++++---- .../detection/test_data_model_integration.py | 2 +- 9 files changed, 47 insertions(+), 24 deletions(-) diff --git a/flash/core/model.py b/flash/core/model.py index 63c47874d8..0d88526cb6 100644 --- a/flash/core/model.py +++ b/flash/core/model.py @@ -324,9 +324,10 @@ def build_data_pipeline( data_source = data_source or old_data_source if isinstance(data_source, str): - assert preprocess is not None, type(preprocess) - # TODO: somehow the preprocess is not well generated when is a Default type - data_source = preprocess.data_source_of_name(data_source) + if preprocess is None: + data_source = DataSource() # TODO: warn the user that we are not using the specified data source + else: + data_source = preprocess.data_source_of_name(data_source) data_pipeline = DataPipeline(data_source, preprocess, postprocess, serializer) self._data_pipeline_state = data_pipeline.initialize(self._data_pipeline_state) diff --git a/flash/data/data_source.py b/flash/data/data_source.py index b4846a74ea..d0d16c856f 100644 --- a/flash/data/data_source.py +++ b/flash/data/data_source.py @@ -106,6 +106,9 @@ def generate_dataset( if not is_none: from flash.data.data_pipeline import DataPipeline + if not isinstance(data, Sequence): + data = [data] + mock_dataset = MockDataset() with CurrentRunningStageFuncContext(running_stage, "load_data", self): load_data 
= getattr( diff --git a/flash/vision/detection/data.py b/flash/vision/detection/data.py index 5cb026f7b5..590e7f9a83 100644 --- a/flash/vision/detection/data.py +++ b/flash/vision/detection/data.py @@ -18,9 +18,10 @@ from torchvision.datasets.folder import default_loader from flash.data.data_module import DataModule -from flash.data.data_source import DataSource +from flash.data.data_source import DataSource, DefaultDataKeys, DefaultDataSources from flash.data.process import Preprocess from flash.utils.imports import _COCO_AVAILABLE +from flash.vision.data import ImagePathsDataSource from flash.vision.detection.transforms import default_transforms if _COCO_AVAILABLE: @@ -84,7 +85,7 @@ def load_data(self, data: Tuple[str, str], dataset: Optional[Any] = None) -> Seq return data def load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]: - sample['input'] = default_loader(sample['input']) + sample[DefaultDataKeys.INPUT] = default_loader(sample[DefaultDataKeys.INPUT]) return sample @@ -103,9 +104,10 @@ def __init__( test_transform=test_transform, predict_transform=predict_transform, data_sources={ + DefaultDataSources.PATHS: ImagePathsDataSource(), "coco": COCODataSource(), }, - default_data_source="coco", + default_data_source=DefaultDataSources.PATHS, ) def collate(self, samples: Any) -> Any: diff --git a/flash/vision/detection/model.py b/flash/vision/detection/model.py index 9204c094b2..e2794bc12c 100644 --- a/flash/vision/detection/model.py +++ b/flash/vision/detection/model.py @@ -24,6 +24,7 @@ from flash.core import Task from flash.core.registry import FlashRegistry +from flash.data.data_source import DefaultDataKeys from flash.vision.backbones import OBJ_DETECTION_BACKBONES from flash.vision.detection.finetuning import ObjectDetectionFineTuning @@ -156,7 +157,7 @@ def get_model( def training_step(self, batch, batch_idx) -> Any: """The training step. 
Overrides ``Task.training_step`` """ - images, targets = batch['input'], batch['target'] + images, targets = batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET] targets = [{k: v for k, v in t.items()} for t in targets] # fasterrcnn takes both images and targets for training, returns loss_dict @@ -166,7 +167,7 @@ def training_step(self, batch, batch_idx) -> Any: return loss def validation_step(self, batch, batch_idx): - images, targets = batch['input'], batch['target'] + images, targets = batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET] # fasterrcnn takes only images for eval() mode outs = self.model(images) iou = torch.stack([_evaluate_iou(t, o) for t, o in zip(targets, outs)]).mean() @@ -178,7 +179,7 @@ def validation_epoch_end(self, outs): return {"avg_val_iou": avg_iou, "log": logs} def test_step(self, batch, batch_idx): - images, targets = batch['input'], batch['target'] + images, targets = batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET] # fasterrcnn takes only images for eval() mode outs = self.model(images) iou = torch.stack([_evaluate_iou(t, o) for t, o in zip(targets, outs)]).mean() @@ -189,5 +190,9 @@ def test_epoch_end(self, outs): logs = {"test_iou": avg_iou} return {"avg_test_iou": avg_iou, "log": logs} + def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: + batch = (batch[DefaultDataKeys.INPUT]) + return super().predict_step(batch, batch_idx, dataloader_idx=dataloader_idx) + def configure_finetune_callback(self): return [ObjectDetectionFineTuning(train_bn=True)] diff --git a/tests/vision/classification/test_data.py b/tests/vision/classification/test_data.py index c76ed75a5b..5154ef4062 100644 --- a/tests/vision/classification/test_data.py +++ b/tests/vision/classification/test_data.py @@ -169,6 +169,7 @@ def test_from_filepaths_visualise_multilabel(tmpdir): test_files=[image_b, image_b], test_targets=[[0, 0, 1], [1, 1, 0]], batch_size=2, + image_size=(64, 64), ) # disable visualisation for testing assert dm.data_fetcher.block_viz_window is True diff --git a/tests/vision/classification/test_data_model_integration.py b/tests/vision/classification/test_data_model_integration.py index 4bd70455ec..426438be72 100644 --- a/tests/vision/classification/test_data_model_integration.py +++ b/tests/vision/classification/test_data_model_integration.py @@ -34,13 +34,16 @@ def test_classification(tmpdir): (tmpdir / "a").mkdir() (tmpdir / "b").mkdir() - _rand_image().save(tmpdir / "a" / "a_1.png") - _rand_image().save(tmpdir / "b" / "a_1.png") - data = ImageClassificationData.from_filepaths( - train_filepaths=[tmpdir / "a", tmpdir / "b"], - train_labels=[0, 1], - train_transform={"per_batch_transform": lambda x: x}, + image_a = str(tmpdir / "a" / "a_1.png") + image_b = str(tmpdir / "b" / "b_1.png") + + _rand_image().save(image_a) + _rand_image().save(image_b) + + data = ImageClassificationData.from_files( + train_files=[image_a, image_b], + train_targets=[0, 1], num_workers=0, batch_size=2, ) diff --git a/tests/vision/classification/test_model.py b/tests/vision/classification/test_model.py index 0aa3ab1835..ee77451e9f 100644 --- a/tests/vision/classification/test_model.py +++ b/tests/vision/classification/test_model.py @@ -17,6 +17,7 @@ from flash import Trainer from flash.core.classification import Probabilities +from flash.data.data_source import DefaultDataKeys from flash.vision import ImageClassifier # ======== Mock functions ======== @@ -25,7 +26,10 @@ class DummyDataset(torch.utils.data.Dataset): def __getitem__(self, index): 
- return torch.rand(3, 224, 224), torch.randint(10, size=(1, )).item() + return { + DefaultDataKeys.INPUT: torch.rand(3, 224, 224), + DefaultDataKeys.TARGET: torch.randint(10, size=(1, )).item(), + } def __len__(self) -> int: return 100 @@ -37,7 +41,10 @@ def __init__(self, num_classes: int): self.num_classes = num_classes def __getitem__(self, index): - return torch.rand(3, 224, 224), torch.randint(0, 2, (self.num_classes, )) + return { + DefaultDataKeys.INPUT: torch.rand(3, 224, 224), + DefaultDataKeys.TARGET: torch.randint(0, 2, (self.num_classes, )), + } def __len__(self) -> int: return 100 @@ -90,8 +97,8 @@ def test_multilabel(tmpdir): train_dl = torch.utils.data.DataLoader(ds, batch_size=2) trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) trainer.finetune(model, train_dl, strategy="freeze_unfreeze") - image, label = ds[0] - predictions = model.predict(image.unsqueeze(0)) + image, label = ds[0][DefaultDataKeys.INPUT], ds[0][DefaultDataKeys.TARGET] + predictions = model.predict({DefaultDataKeys.INPUT: image}) assert (torch.tensor(predictions) > 1).sum() == 0 assert (torch.tensor(predictions) < 0).sum() == 0 assert len(predictions[0]) == num_classes == len(label) diff --git a/tests/vision/detection/test_data.py b/tests/vision/detection/test_data.py index fec4b9a5e8..39f8a191eb 100644 --- a/tests/vision/detection/test_data.py +++ b/tests/vision/detection/test_data.py @@ -6,6 +6,7 @@ from PIL import Image from pytorch_lightning.utilities import _module_available +from flash.data.data_source import DefaultDataKeys from flash.utils.imports import _COCO_AVAILABLE from flash.vision.detection.data import ObjectDetectionData @@ -83,7 +84,7 @@ def test_image_detector_data_from_coco(tmpdir): datamodule = ObjectDetectionData.from_coco(train_folder=train_folder, train_ann_file=coco_ann_path, batch_size=1) data = next(iter(datamodule.train_dataloader())) - imgs, labels = data + imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET] assert len(imgs) == 1 assert imgs[0].shape == (3, 1080, 1920) @@ -101,11 +102,11 @@ def test_image_detector_data_from_coco(tmpdir): test_folder=train_folder, test_ann_file=coco_ann_path, batch_size=1, - num_workers=0 + num_workers=0, ) data = next(iter(datamodule.val_dataloader())) - imgs, labels = data + imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET] assert len(imgs) == 1 assert imgs[0].shape == (3, 1080, 1920) @@ -113,7 +114,7 @@ def test_image_detector_data_from_coco(tmpdir): assert list(labels[0].keys()) == ['boxes', 'labels', 'image_id', 'area', 'iscrowd'] data = next(iter(datamodule.test_dataloader())) - imgs, labels = data + imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET] assert len(imgs) == 1 assert imgs[0].shape == (3, 1080, 1920) diff --git a/tests/vision/detection/test_data_model_integration.py b/tests/vision/detection/test_data_model_integration.py index 8f90279959..8c71115671 100644 --- a/tests/vision/detection/test_data_model_integration.py +++ b/tests/vision/detection/test_data_model_integration.py @@ -43,5 +43,5 @@ def test_detection(tmpdir, model, backbone): Image.new('RGB', (512, 512)).save(test_image_one) Image.new('RGB', (512, 512)).save(test_image_two) - test_images = [test_image_one, test_image_two] + test_images = [str(test_image_one), str(test_image_two)] model.predict(test_images) From 8064c65570fd9c07854fbb6475ce969d0fdd41e4 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Thu, 6 May 2021 18:57:53 +0100 Subject: [PATCH 53/78] Fixes --- 
flash/vision/detection/model.py | 4 ++-- tests/vision/detection/test_model.py | 10 ++++++---- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/flash/vision/detection/model.py b/flash/vision/detection/model.py index e2794bc12c..dba922b8f9 100644 --- a/flash/vision/detection/model.py +++ b/flash/vision/detection/model.py @@ -191,8 +191,8 @@ def test_epoch_end(self, outs): return {"avg_test_iou": avg_iou, "log": logs} def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: - batch = (batch[DefaultDataKeys.INPUT]) - return super().predict_step(batch, batch_idx, dataloader_idx=dataloader_idx) + images = batch[DefaultDataKeys.INPUT] + return self.model(images) def configure_finetune_callback(self): return [ObjectDetectionFineTuning(train_bn=True)] diff --git a/tests/vision/detection/test_model.py b/tests/vision/detection/test_model.py index 110b55d43c..90fc9a1295 100644 --- a/tests/vision/detection/test_model.py +++ b/tests/vision/detection/test_model.py @@ -16,11 +16,12 @@ from pytorch_lightning import Trainer from torch.utils.data import DataLoader, Dataset +from flash.data.data_source import DefaultDataKeys from flash.vision import ObjectDetector -def collate_fn(batch): - return tuple(zip(*batch)) +def collate_fn(samples): + return {key: [sample[key] for sample in samples] for key in samples[0]} class DummyDetectionDataset(Dataset): @@ -45,7 +46,7 @@ def __getitem__(self, idx): img = torch.rand(self.img_shape) boxes = torch.tensor([self._random_bbox() for _ in range(self.num_boxes)]) labels = torch.randint(self.num_classes, (self.num_boxes, )) - return img, {"boxes": boxes, "labels": labels} + return {DefaultDataKeys.INPUT: img, DefaultDataKeys.TARGET: {"boxes": boxes, "labels": labels}} def test_init(): @@ -55,7 +56,8 @@ def test_init(): batch_size = 2 ds = DummyDetectionDataset((3, 224, 224), 1, 2, 10) dl = DataLoader(ds, collate_fn=collate_fn, batch_size=batch_size) - img, target = next(iter(dl)) + data = next(iter(dl)) + img, target = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET] out = model(img) From a32560c7a69dbca1674585c5acc00a6b75b726a9 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Thu, 6 May 2021 19:00:44 +0100 Subject: [PATCH 54/78] Fixes --- tests/vision/classification/test_data.py | 1 - tests/vision/detection/test_model.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/vision/classification/test_data.py b/tests/vision/classification/test_data.py index 5154ef4062..19f49b672a 100644 --- a/tests/vision/classification/test_data.py +++ b/tests/vision/classification/test_data.py @@ -224,7 +224,6 @@ def run(transform: Any = None): assert imgs.shape == (B, 3, H, W) assert labels.shape == (B, ) - #run() run(_to_tensor) diff --git a/tests/vision/detection/test_model.py b/tests/vision/detection/test_model.py index 90fc9a1295..9d3f0a5dc6 100644 --- a/tests/vision/detection/test_model.py +++ b/tests/vision/detection/test_model.py @@ -57,7 +57,7 @@ def test_init(): ds = DummyDetectionDataset((3, 224, 224), 1, 2, 10) dl = DataLoader(ds, collate_fn=collate_fn, batch_size=batch_size) data = next(iter(dl)) - img, target = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET] + img = data[DefaultDataKeys.INPUT] out = model(img) From 4d34d94492d86cefd909653c597e770d5f42496c Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Thu, 6 May 2021 19:02:40 +0100 Subject: [PATCH 55/78] Fixes --- flash_examples/predict/image_classification.py | 2 +- flash_examples/predict/image_classification_multi_label.py | 4 +++- 
flash_examples/predict/summarization.py | 2 +- flash_examples/predict/tabular_classification.py | 2 +- flash_examples/predict/text_classification.py | 2 +- flash_examples/predict/translation.py | 2 +- 6 files changed, 8 insertions(+), 6 deletions(-) diff --git a/flash_examples/predict/image_classification.py b/flash_examples/predict/image_classification.py index 3f3f90fef7..fe697b2963 100644 --- a/flash_examples/predict/image_classification.py +++ b/flash_examples/predict/image_classification.py @@ -19,7 +19,7 @@ download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", "data/") # 2. Load the model from a checkpoint -model = ImageClassifier.load_from_checkpoint("../finetuning/image_classification_model.pt") +model = ImageClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/image_classification_model.pt") # 3a. Predict what's on a few images! ants or bees? predictions = model.predict([ diff --git a/flash_examples/predict/image_classification_multi_label.py b/flash_examples/predict/image_classification_multi_label.py index 77b5f35978..59e4c7da9e 100644 --- a/flash_examples/predict/image_classification_multi_label.py +++ b/flash_examples/predict/image_classification_multi_label.py @@ -40,7 +40,9 @@ def show_per_batch_transform(self, batch: Any, _) -> None: # 3. Load the model from a checkpoint -model = ImageClassifier.load_from_checkpoint("../finetuning/image_classification_multi_label_model.pt", ) +model = ImageClassifier.load_from_checkpoint( + "https://flash-weights.s3.amazonaws.com/image_classification_multi_label_model.pt", +) # 4a. Predict the genres of a few movie posters! predictions = model.predict([ diff --git a/flash_examples/predict/summarization.py b/flash_examples/predict/summarization.py index 3acaac05a9..ff59c6cfa3 100644 --- a/flash_examples/predict/summarization.py +++ b/flash_examples/predict/summarization.py @@ -20,7 +20,7 @@ download_data("https://pl-flash-data.s3.amazonaws.com/xsum.zip", "data/") # 2. Load the model from a checkpoint -model = SummarizationTask.load_from_checkpoint("../finetuning/summarization_model_xsum.pt") +model = SummarizationTask.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/summarization_model_xsum.pt") # 2a. Summarize an article! predictions = model.predict([ diff --git a/flash_examples/predict/tabular_classification.py b/flash_examples/predict/tabular_classification.py index 88fb569c16..a874d1f99f 100644 --- a/flash_examples/predict/tabular_classification.py +++ b/flash_examples/predict/tabular_classification.py @@ -19,7 +19,7 @@ download_data("https://pl-flash-data.s3.amazonaws.com/titanic.zip", "data/") # 2. Load the model from a checkpoint -model = TabularClassifier.load_from_checkpoint("../finetuning/tabular_classification_model.pt") +model = TabularClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/tabular_classification_model.pt") model.serializer = Labels(['Did not survive', 'Survived']) diff --git a/flash_examples/predict/text_classification.py b/flash_examples/predict/text_classification.py index 705fee9f92..372250b21f 100644 --- a/flash_examples/predict/text_classification.py +++ b/flash_examples/predict/text_classification.py @@ -21,7 +21,7 @@ download_data("https://pl-flash-data.s3.amazonaws.com/imdb.zip", "data/") # 2. 
Load the model from a checkpoint -model = TextClassifier.load_from_checkpoint("../finetuning/text_classification_model.pt") +model = TextClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/text_classification_model.pt") model.serializer = Labels() diff --git a/flash_examples/predict/translation.py b/flash_examples/predict/translation.py index 112658ad33..ed1498232f 100644 --- a/flash_examples/predict/translation.py +++ b/flash_examples/predict/translation.py @@ -20,7 +20,7 @@ download_data("https://pl-flash-data.s3.amazonaws.com/wmt_en_ro.zip", "data/") # 2. Load the model from a checkpoint -model = TranslationTask.load_from_checkpoint("../finetuning/translation_model_en_ro.pt") +model = TranslationTask.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/translation_model_en_ro.pt") # 3. Translate a few sentences! predictions = model.predict([ From c85a8dbe8a53d835cb362e4d91345f52484af810 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Thu, 6 May 2021 19:17:19 +0100 Subject: [PATCH 56/78] Fixes --- flash/data/data_source.py | 3 --- tests/vision/classification/test_model.py | 2 +- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/flash/data/data_source.py b/flash/data/data_source.py index d0d16c856f..b4846a74ea 100644 --- a/flash/data/data_source.py +++ b/flash/data/data_source.py @@ -106,9 +106,6 @@ def generate_dataset( if not is_none: from flash.data.data_pipeline import DataPipeline - if not isinstance(data, Sequence): - data = [data] - mock_dataset = MockDataset() with CurrentRunningStageFuncContext(running_stage, "load_data", self): load_data = getattr( diff --git a/tests/vision/classification/test_model.py b/tests/vision/classification/test_model.py index ee77451e9f..94e26d889d 100644 --- a/tests/vision/classification/test_model.py +++ b/tests/vision/classification/test_model.py @@ -98,7 +98,7 @@ def test_multilabel(tmpdir): trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) trainer.finetune(model, train_dl, strategy="freeze_unfreeze") image, label = ds[0][DefaultDataKeys.INPUT], ds[0][DefaultDataKeys.TARGET] - predictions = model.predict({DefaultDataKeys.INPUT: image}) + predictions = model.predict([{DefaultDataKeys.INPUT: image}]) assert (torch.tensor(predictions) > 1).sum() == 0 assert (torch.tensor(predictions) < 0).sum() == 0 assert len(predictions[0]) == num_classes == len(label) From c93a649e0a02a19433795d002ed01a492284c4ed Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Thu, 6 May 2021 20:52:09 +0100 Subject: [PATCH 57/78] Fixes --- flash/data/data_source.py | 2 +- flash_examples/custom_task.py | 24 ++-- tests/data/test_auto_dataset.py | 115 ++---------------- .../test_data_model_integration.py | 1 + 4 files changed, 26 insertions(+), 116 deletions(-) diff --git a/flash/data/data_source.py b/flash/data/data_source.py index b4846a74ea..0e16f43427 100644 --- a/flash/data/data_source.py +++ b/flash/data/data_source.py @@ -110,7 +110,7 @@ def generate_dataset( with CurrentRunningStageFuncContext(running_stage, "load_data", self): load_data = getattr( self, DataPipeline._resolve_function_hierarchy( - 'load_data', + "load_data", self, running_stage, DataSource, diff --git a/flash_examples/custom_task.py b/flash_examples/custom_task.py index 4c8e5b9c04..ff5c395d82 100644 --- a/flash_examples/custom_task.py +++ b/flash_examples/custom_task.py @@ -9,6 +9,7 @@ import flash from flash.data.auto_dataset import AutoDataset +from flash.data.data_source import DataSource from flash.data.process import Postprocess, Preprocess 
seed_everything(42) @@ -41,22 +42,28 @@ def forward(self, x): return self.model(x) -class NumpyPreprocess(Preprocess): +class NumpyDataSource(DataSource): def load_data(self, data: Tuple[ND, ND], dataset: AutoDataset) -> List[Tuple[ND, float]]: if self.training: dataset.num_inputs = data[0].shape[1] return [(x, y) for x, y in zip(*data)] + def predict_load_data(self, data: ND) -> ND: + return data + + +class NumpyPreprocess(Preprocess): + + def __init__(self): + super().__init__(data_sources={"numpy": NumpyDataSource()}, default_data_source="numpy") + def to_tensor_transform(self, sample: Any) -> Tuple[Tensor, Tensor]: x, y = sample x = torch.from_numpy(x).float() y = torch.tensor(y, dtype=torch.float) return x, y - def predict_load_data(self, data: ND) -> ND: - return data - def predict_to_tensor_transform(self, sample: ND) -> ND: return torch.from_numpy(sample).float() @@ -77,12 +84,13 @@ def from_dataset(cls, x: ND, y: ND, preprocess: Preprocess, batch_size: int = 64 x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=.20, random_state=0) - dm = cls.from_load_data_inputs( - train_load_data_input=(x_train, y_train), - test_load_data_input=(x_test, y_test), + dm = cls.from_data_source( + "numpy", + train_data=(x_train, y_train), + test_data=(x_test, y_test), preprocess=preprocess, batch_size=batch_size, - num_workers=num_workers + num_workers=num_workers, ) dm.num_inputs = dm.train_dataset.num_inputs return dm diff --git a/tests/data/test_auto_dataset.py b/tests/data/test_auto_dataset.py index ca57b41329..b235c3cf38 100644 --- a/tests/data/test_auto_dataset.py +++ b/tests/data/test_auto_dataset.py @@ -23,7 +23,7 @@ from flash.data.process import Preprocess -class _AutoDatasetTestPreprocess(Preprocess): +class _AutoDatasetTestDataSource(DataSource): def __init__(self, with_dset: bool): self._callbacks: List[FlashCallback] = [] @@ -49,13 +49,6 @@ def __init__(self, with_dset: bool): self.train_load_data = self.train_load_data_no_dset self.train_load_sample = self.train_load_sample_no_dset - def get_state_dict(self) -> Dict[str, Any]: - return {"with_dset": self.with_dset} - - @classmethod - def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool): - return _AutoDatasetTestPreprocess(state_dict["with_dset"]) - def load_data_no_dset(self, data): self.load_data_count += 1 return data @@ -171,61 +164,6 @@ def test_iterable_autodataset_smoke(): assert next(itr) == 2 -# TODO: do we remove ? -@pytest.mark.parametrize( - "with_dataset,with_running_stage", - [ - (True, False), - (True, True), - (False, False), - (False, True), - ], -) -def test_autodataset_with_functions( - with_dataset: bool, - with_running_stage: bool, -): - - functions = _AutoDatasetTestPreprocess(with_dataset) - - load_sample_func = functions.load_sample - load_data_func = functions.load_data - - if with_running_stage: - running_stage = RunningStage.TRAINING - else: - running_stage = None - dset = AutoDataset( - range(10), - load_data=load_data_func, - load_sample=load_sample_func, - running_stage=running_stage, - ) - - assert len(dset) == 10 - - for idx in range(len(dset)): - dset[idx] - - if with_dataset: - assert dset.load_sample_was_called - assert dset.load_data_was_called - assert functions.load_sample_with_dataset_count == len(dset) - assert functions.load_data_with_dataset_count == 1 - else: - assert functions.load_data_count == 1 - assert functions.load_sample_count == len(dset) - - -# TODO: do we remove ? 
-def test_autodataset_warning(): - with pytest.warns( - UserWarning, match="``datapipeline`` is specified but load_sample and/or load_data are also specified" - ): - AutoDataset(range(10), load_data=lambda x: x, load_sample=lambda x: x, data_pipeline=DataPipeline()) - - -# TODO: do we remove ? @pytest.mark.parametrize( "with_dataset", [ @@ -233,12 +171,11 @@ def test_autodataset_warning(): False, ], ) -def test_preprocessing_data_pipeline_with_running_stage(with_dataset): - pipe = DataPipeline(_AutoDatasetTestPreprocess(with_dataset)) - +def test_preprocessing_data_source_with_running_stage(with_dataset): + data_source = _AutoDatasetTestDataSource(with_dataset) running_stage = RunningStage.TRAINING - dataset = pipe._generate_auto_dataset(range(10), running_stage=running_stage) + dataset = data_source.generate_dataset(range(10), running_stage=running_stage) assert len(dataset) == 10 @@ -248,44 +185,8 @@ def test_preprocessing_data_pipeline_with_running_stage(with_dataset): if with_dataset: assert dataset.train_load_sample_was_called assert dataset.train_load_data_was_called - assert pipe._preprocess_pipeline.train_load_sample_with_dataset_count == len(dataset) - assert pipe._preprocess_pipeline.train_load_data_with_dataset_count == 1 - else: - assert pipe._preprocess_pipeline.train_load_sample_count == len(dataset) - assert pipe._preprocess_pipeline.train_load_data_count == 1 - - -# TODO: do we remove ? we are testing DataPipeline here. -@pytest.mark.parametrize( - "with_dataset", - [ - True, - False, - ], -) -def test_preprocessing_data_pipeline_no_running_stage(with_dataset): - pipe = DataPipeline(_AutoDatasetTestPreprocess(with_dataset)) - - dataset = pipe._generate_auto_dataset(range(10), running_stage=None) - - with pytest.raises(RuntimeError, match='`__len__` for `load_sample`'): - for idx in range(len(dataset)): - dataset[idx] - - # will be triggered when running stage is set - if with_dataset: - assert not hasattr(dataset, 'load_sample_was_called') - assert not hasattr(dataset, 'load_data_was_called') - assert pipe._preprocess_pipeline.load_sample_with_dataset_count == 0 - assert pipe._preprocess_pipeline.load_data_with_dataset_count == 0 - else: - assert pipe._preprocess_pipeline.load_sample_count == 0 - assert pipe._preprocess_pipeline.load_data_count == 0 - - dataset.running_stage = RunningStage.TRAINING - - if with_dataset: - assert pipe._preprocess_pipeline.train_load_data_with_dataset_count == 1 - assert dataset.train_load_data_was_called + assert data_source.train_load_sample_with_dataset_count == len(dataset) + assert data_source.train_load_data_with_dataset_count == 1 else: - assert pipe._preprocess_pipeline.train_load_data_count == 1 + assert data_source.train_load_sample_count == len(dataset) + assert data_source.train_load_data_count == 1 diff --git a/tests/vision/classification/test_data_model_integration.py b/tests/vision/classification/test_data_model_integration.py index 426438be72..2425e3f760 100644 --- a/tests/vision/classification/test_data_model_integration.py +++ b/tests/vision/classification/test_data_model_integration.py @@ -46,6 +46,7 @@ def test_classification(tmpdir): train_targets=[0, 1], num_workers=0, batch_size=2, + image_size=(64, 64), ) model = ImageClassifier(num_classes=2, backbone="resnet18") trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) From 02fd77b0fa3823293b7b0dc860aa6f23b0016d53 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Thu, 6 May 2021 20:59:18 +0100 Subject: [PATCH 58/78] Fixes --- flash/vision/embedding/model.py | 19 
++++++++++++++++++- tests/data/test_callbacks.py | 4 ++-- 2 files changed, 20 insertions(+), 3 deletions(-) diff --git a/flash/vision/embedding/model.py b/flash/vision/embedding/model.py index f43dabcfaa..1392228e37 100644 --- a/flash/vision/embedding/model.py +++ b/flash/vision/embedding/model.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable, Mapping, Optional, Sequence, Type, Union +from typing import Any, Callable, Mapping, Optional, Sequence, Type, Union import torch from pytorch_lightning.utilities.distributed import rank_zero_warn @@ -21,6 +21,7 @@ from flash.core import Task from flash.core.registry import FlashRegistry +from flash.data.data_source import DefaultDataKeys from flash.vision.backbones import IMAGE_CLASSIFIER_BACKBONES from flash.vision.classification.data import ImageClassificationData, ImageClassificationPreprocess @@ -108,3 +109,19 @@ def forward(self, x) -> torch.Tensor: x = self.head(x) return x + + def training_step(self, batch: Any, batch_idx: int) -> Any: + batch = (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET]) + return super().training_step(batch, batch_idx) + + def validation_step(self, batch: Any, batch_idx: int) -> Any: + batch = (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET]) + return super().validation_step(batch, batch_idx) + + def test_step(self, batch: Any, batch_idx: int) -> Any: + batch = (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET]) + return super().test_step(batch, batch_idx) + + def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: + batch = (batch[DefaultDataKeys.INPUT]) + return super().predict_step(batch, batch_idx, dataloader_idx=dataloader_idx) diff --git a/tests/data/test_callbacks.py b/tests/data/test_callbacks.py index 057920a8c2..47a118c518 100644 --- a/tests/data/test_callbacks.py +++ b/tests/data/test_callbacks.py @@ -27,7 +27,7 @@ from flash.data.data_module import DataModule from flash.data.data_source import DefaultDataKeys from flash.data.process import DefaultPreprocess -from flash.data.utils import _PREPROCESS_FUNCS, _STAGES_PREFIX +from flash.data.utils import _CALLBACK_FUNCS, _STAGES_PREFIX from flash.vision import ImageClassificationData @@ -154,7 +154,7 @@ def configure_data_fetcher(*args, **kwargs) -> CustomBaseVisualization: for stage in _STAGES_PREFIX.values(): for _ in range(num_tests): - for fcn_name in _PREPROCESS_FUNCS: + for fcn_name in _CALLBACK_FUNCS: fcn = getattr(dm, f"show_{stage}_batch") fcn(fcn_name, reset=True) From 3d780fa48e96863b993ff8319649f8a2876103ec Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Thu, 6 May 2021 21:17:28 +0100 Subject: [PATCH 59/78] Fix docs build --- docs/source/general/data.rst | 13 +++++++++++++ docs/source/reference/image_classification.rst | 4 ---- docs/source/reference/video_classification.rst | 2 -- 3 files changed, 13 insertions(+), 6 deletions(-) diff --git a/docs/source/general/data.rst b/docs/source/general/data.rst index 1eb89405b7..b88ae138ee 100644 --- a/docs/source/general/data.rst +++ b/docs/source/general/data.rst @@ -21,6 +21,8 @@ Here are common terms you need to be familiar with: - The :class:`~flash.data.data_module.DataModule` contains the dataset, transforms and dataloaders. 
* - :class:`~flash.data.data_pipeline.DataPipeline` - The :class:`~flash.data.data_pipeline.DataPipeline` is Flash internal object to manage :class:`~flash.data.process.Preprocess` and :class:`~flash.data.process.Postprocess` objects. + * - :class:`~flash.data.data_source.DataSource` + - The :class:`~flash.data.data_source.DataSource` provides a hook-based API for creating data sets. * - :class:`~flash.data.process.Preprocess` - The :class:`~flash.data.process.Preprocess` provides a simple hook-based API to encapsulate your pre-processing logic. The :class:`~flash.data.process.Preprocess` provides multiple hooks such as :meth:`~flash.data.process.Preprocess.load_data` @@ -275,6 +277,17 @@ Example:: API reference ************* +.. _data_source: + +DataSource +__________ + +.. autoclass:: flash.data.data_source.DataSource + :members: + + +---------- + .. _preprocess: Preprocess diff --git a/docs/source/reference/image_classification.rst b/docs/source/reference/image_classification.rst index ac12aea2cf..54f841ad5f 100644 --- a/docs/source/reference/image_classification.rst +++ b/docs/source/reference/image_classification.rst @@ -183,8 +183,4 @@ ImageClassificationData .. autoclass:: flash.vision.ImageClassificationData -.. automethod:: flash.vision.ImageClassificationData.from_filepaths - -.. automethod:: flash.vision.ImageClassificationData.from_folders - .. autoclass:: flash.vision.ImageClassificationPreprocess diff --git a/docs/source/reference/video_classification.rst b/docs/source/reference/video_classification.rst index e088a556ea..6b7d3c08d1 100644 --- a/docs/source/reference/video_classification.rst +++ b/docs/source/reference/video_classification.rst @@ -152,5 +152,3 @@ VideoClassificationData ----------------------- .. autoclass:: flash.video.VideoClassificationData - -.. automethod:: flash.video.VideoClassificationData.from_paths From 704f558bf0e0e64ddbee8312e5d47b0094a033e6 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Fri, 7 May 2021 09:05:03 +0100 Subject: [PATCH 60/78] Fix docs build --- docs/source/general/data.rst | 1 - flash/data/data_pipeline.py | 3 +- flash/data/data_source.py | 2 +- flash/data/process.py | 106 +------------------------------ flash/data/properties.py | 120 +++++++++++++++++++++++++++++++++++ tests/data/test_process.py | 3 +- 6 files changed, 128 insertions(+), 107 deletions(-) create mode 100644 flash/data/properties.py diff --git a/docs/source/general/data.rst b/docs/source/general/data.rst index b88ae138ee..fec10ec9d6 100644 --- a/docs/source/general/data.rst +++ b/docs/source/general/data.rst @@ -338,7 +338,6 @@ __________ .. 
autoclass:: flash.data.data_module.DataModule :members: - from_load_data_inputs, train_dataset, val_dataset, test_dataset, diff --git a/flash/data/data_pipeline.py b/flash/data/data_pipeline.py index cbd6ba4700..2db825affa 100644 --- a/flash/data/data_pipeline.py +++ b/flash/data/data_pipeline.py @@ -27,7 +27,8 @@ from flash.data.auto_dataset import IterableAutoDataset from flash.data.batch import _PostProcessor, _PreProcessor, _Sequential from flash.data.data_source import DataSource -from flash.data.process import DefaultPreprocess, Postprocess, Preprocess, ProcessState, Serializer +from flash.data.process import DefaultPreprocess, Postprocess, Preprocess, Serializer +from flash.data.properties import ProcessState from flash.data.utils import _POSTPROCESS_FUNCS, _PREPROCESS_FUNCS, _STAGES_PREFIX if TYPE_CHECKING: diff --git a/flash/data/data_source.py b/flash/data/data_source.py index 0e16f43427..b385b2db0c 100644 --- a/flash/data/data_source.py +++ b/flash/data/data_source.py @@ -24,7 +24,7 @@ from torchvision.datasets.folder import has_file_allowed_extension, make_dataset from flash.data.auto_dataset import AutoDataset, BaseAutoDataset, IterableAutoDataset -from flash.data.process import ProcessState, Properties +from flash.data.properties import ProcessState, Properties from flash.data.utils import CurrentRunningStageFuncContext diff --git a/flash/data/process.py b/flash/data/process.py index 01518b066f..e86095b12e 100644 --- a/flash/data/process.py +++ b/flash/data/process.py @@ -13,8 +13,7 @@ # limitations under the License. import os from abc import ABC, abstractclassmethod, abstractmethod -from dataclasses import dataclass -from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Type, TYPE_CHECKING, TypeVar, Union +from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, TYPE_CHECKING, TypeVar, Union import torch from pytorch_lightning.trainer.states import RunningStage @@ -25,111 +24,12 @@ from flash.data.batch import default_uncollate from flash.data.callback import FlashCallback +from flash.data.data_source import DataSource +from flash.data.properties import Properties from flash.data.utils import _PREPROCESS_FUNCS, _STAGES_PREFIX, convert_to_modules if TYPE_CHECKING: from flash.data.data_pipeline import DataPipelineState - from flash.data.data_source import DataSource - - -@dataclass(unsafe_hash=True, frozen=True) -class ProcessState: - """ - Base class for all process states - """ - pass - - -STATE_TYPE = TypeVar('STATE_TYPE', bound=ProcessState) - - -class Properties: - - def __init__(self): - super().__init__() - - self._running_stage: Optional[RunningStage] = None - self._current_fn: Optional[str] = None - self._data_pipeline_state: Optional['DataPipelineState'] = None - self._state: Dict[Type[ProcessState], ProcessState] = {} - - def get_state(self, state_type: Type[STATE_TYPE]) -> Optional[STATE_TYPE]: - if state_type in self._state: - return self._state[state_type] - if self._data_pipeline_state is not None: - return self._data_pipeline_state.get_state(state_type) - else: - return None - - def set_state(self, state: ProcessState): - self._state[type(state)] = state - if self._data_pipeline_state is not None: - self._data_pipeline_state.set_state(state) - - def attach_data_pipeline_state(self, data_pipeline_state: 'DataPipelineState'): - self._data_pipeline_state = data_pipeline_state - for state in self._state.values(): - self._data_pipeline_state.set_state(state) - - @property - def current_fn(self) -> Optional[str]: - return 
self._current_fn - - @current_fn.setter - def current_fn(self, current_fn: str): - self._current_fn = current_fn - - @property - def running_stage(self) -> Optional[RunningStage]: - return self._running_stage - - @running_stage.setter - def running_stage(self, running_stage: RunningStage): - self._running_stage = running_stage - - @property - def training(self) -> bool: - return self._running_stage == RunningStage.TRAINING - - @training.setter - def training(self, val: bool) -> None: - if val: - self._running_stage = RunningStage.TRAINING - elif self.training: - self._running_stage = None - - @property - def testing(self) -> bool: - return self._running_stage == RunningStage.TESTING - - @testing.setter - def testing(self, val: bool) -> None: - if val: - self._running_stage = RunningStage.TESTING - elif self.testing: - self._running_stage = None - - @property - def predicting(self) -> bool: - return self._running_stage == RunningStage.PREDICTING - - @predicting.setter - def predicting(self, val: bool) -> None: - if val: - self._running_stage = RunningStage.PREDICTING - elif self.predicting: - self._running_stage = None - - @property - def validating(self) -> bool: - return self._running_stage == RunningStage.VALIDATING - - @validating.setter - def validating(self, val: bool) -> None: - if val: - self._running_stage = RunningStage.VALIDATING - elif self.validating: - self._running_stage = None class BasePreprocess(ABC): diff --git a/flash/data/properties.py b/flash/data/properties.py new file mode 100644 index 0000000000..2a2934a3d9 --- /dev/null +++ b/flash/data/properties.py @@ -0,0 +1,120 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from dataclasses import dataclass +from typing import Dict, Optional, Type, TYPE_CHECKING, TypeVar + +from pytorch_lightning.trainer.states import RunningStage + +if TYPE_CHECKING: + from flash.data.data_pipeline import DataPipelineState + + +@dataclass(unsafe_hash=True, frozen=True) +class ProcessState: + """ + Base class for all process states + """ + pass + + +STATE_TYPE = TypeVar('STATE_TYPE', bound=ProcessState) + + +class Properties: + + def __init__(self): + super().__init__() + + self._running_stage: Optional[RunningStage] = None + self._current_fn: Optional[str] = None + self._data_pipeline_state: Optional['DataPipelineState'] = None + self._state: Dict[Type[ProcessState], ProcessState] = {} + + def get_state(self, state_type: Type[STATE_TYPE]) -> Optional[STATE_TYPE]: + if state_type in self._state: + return self._state[state_type] + if self._data_pipeline_state is not None: + return self._data_pipeline_state.get_state(state_type) + else: + return None + + def set_state(self, state: ProcessState): + self._state[type(state)] = state + if self._data_pipeline_state is not None: + self._data_pipeline_state.set_state(state) + + def attach_data_pipeline_state(self, data_pipeline_state: 'DataPipelineState'): + self._data_pipeline_state = data_pipeline_state + for state in self._state.values(): + self._data_pipeline_state.set_state(state) + + @property + def current_fn(self) -> Optional[str]: + return self._current_fn + + @current_fn.setter + def current_fn(self, current_fn: str): + self._current_fn = current_fn + + @property + def running_stage(self) -> Optional[RunningStage]: + return self._running_stage + + @running_stage.setter + def running_stage(self, running_stage: RunningStage): + self._running_stage = running_stage + + @property + def training(self) -> bool: + return self._running_stage == RunningStage.TRAINING + + @training.setter + def training(self, val: bool) -> None: + if val: + self._running_stage = RunningStage.TRAINING + elif self.training: + self._running_stage = None + + @property + def testing(self) -> bool: + return self._running_stage == RunningStage.TESTING + + @testing.setter + def testing(self, val: bool) -> None: + if val: + self._running_stage = RunningStage.TESTING + elif self.testing: + self._running_stage = None + + @property + def predicting(self) -> bool: + return self._running_stage == RunningStage.PREDICTING + + @predicting.setter + def predicting(self, val: bool) -> None: + if val: + self._running_stage = RunningStage.PREDICTING + elif self.predicting: + self._running_stage = None + + @property + def validating(self) -> bool: + return self._running_stage == RunningStage.VALIDATING + + @validating.setter + def validating(self, val: bool) -> None: + if val: + self._running_stage = RunningStage.VALIDATING + elif self.validating: + self._running_stage = None diff --git a/tests/data/test_process.py b/tests/data/test_process.py index 6f4d59f3d0..66df027b5b 100644 --- a/tests/data/test_process.py +++ b/tests/data/test_process.py @@ -21,7 +21,8 @@ from flash import Task, Trainer from flash.core.classification import Labels, LabelsState from flash.data.data_pipeline import DataPipeline, DataPipelineState, DefaultPreprocess -from flash.data.process import ProcessState, Properties, Serializer, SerializerMapping +from flash.data.process import Serializer, SerializerMapping +from flash.data.properties import ProcessState, Properties def test_properties_data_pipeline_state(): From edfc38e347c366f8aa5ea982cf323c53c54709ea Mon Sep 17 00:00:00 2001 From: Ethan Harris 
Date: Fri, 7 May 2021 10:06:49 +0100 Subject: [PATCH 61/78] Fix examples --- flash/text/seq2seq/translation/data.py | 2 +- flash_examples/finetuning/image_classification.py | 2 +- flash_examples/finetuning/translation.py | 6 +++++- flash_examples/finetuning/video_classification.py | 3 +-- flash_examples/predict/translation.py | 4 +--- tests/examples/test_scripts.py | 4 ++-- 6 files changed, 11 insertions(+), 10 deletions(-) diff --git a/flash/text/seq2seq/translation/data.py b/flash/text/seq2seq/translation/data.py index 32a91746b8..1475227086 100644 --- a/flash/text/seq2seq/translation/data.py +++ b/flash/text/seq2seq/translation/data.py @@ -24,7 +24,7 @@ def __init__( val_transform: Optional[Dict[str, Callable]] = None, test_transform: Optional[Dict[str, Callable]] = None, predict_transform: Optional[Dict[str, Callable]] = None, - backbone: str = "facebook/mbart-large-en-ro", + backbone: str = "Helsinki-NLP/opus-mt-en-ro", max_source_length: int = 128, max_target_length: int = 128, padding: Union[str, bool] = 'max_length' diff --git a/flash_examples/finetuning/image_classification.py b/flash_examples/finetuning/image_classification.py index a8326e3670..2ebc668f95 100644 --- a/flash_examples/finetuning/image_classification.py +++ b/flash_examples/finetuning/image_classification.py @@ -48,7 +48,7 @@ def fn_resnet(pretrained: bool = True): print(ImageClassifier.available_backbones()) # 4. Build the model -model = ImageClassifier(backbone="dino_vitb16", num_classes=datamodule.num_classes) +model = ImageClassifier(backbone="resnet18", num_classes=datamodule.num_classes) # 5. Create the trainer. trainer = flash.Trainer(max_epochs=1, limit_train_batches=1, limit_val_batches=1) diff --git a/flash_examples/finetuning/translation.py b/flash_examples/finetuning/translation.py index 69440bed66..2a3b1bebf9 100644 --- a/flash_examples/finetuning/translation.py +++ b/flash_examples/finetuning/translation.py @@ -34,7 +34,11 @@ model = TranslationTask() # 4. Create the trainer -trainer = flash.Trainer(precision=32, gpus=int(torch.cuda.is_available()), fast_dev_run=True) +trainer = flash.Trainer( + precision=16 if torch.cuda.is_available() else 32, + gpus=int(torch.cuda.is_available()), + fast_dev_run=True, +) # 5. Fine-tune the model trainer.finetune(model, datamodule=datamodule) diff --git a/flash_examples/finetuning/video_classification.py b/flash_examples/finetuning/video_classification.py index 4efa815dee..78105f1c4e 100644 --- a/flash_examples/finetuning/video_classification.py +++ b/flash_examples/finetuning/video_classification.py @@ -79,7 +79,6 @@ def make_transform( train_transform=make_transform(train_post_tensor_transform), val_transform=make_transform(val_post_tensor_transform), predict_transform=make_transform(val_post_tensor_transform), - num_workers=8, batch_size=8, clip_sampler="uniform", clip_duration=2, @@ -97,7 +96,7 @@ def make_transform( model.serializer = Labels() # 6. Finetune the model - trainer = flash.Trainer(max_epochs=3) + trainer = flash.Trainer(max_epochs=1, limit_train_batches=1, limit_val_batches=1) trainer.finetune(model, datamodule=datamodule, strategy=NoFreeze()) trainer.save_checkpoint("video_classification.pt") diff --git a/flash_examples/predict/translation.py b/flash_examples/predict/translation.py index ed1498232f..cd6009f4db 100644 --- a/flash_examples/predict/translation.py +++ b/flash_examples/predict/translation.py @@ -11,10 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. -from pytorch_lightning import Trainer - from flash.data.utils import download_data -from flash.text import TranslationData, TranslationTask +from flash.text import TranslationTask # 1. Download the data download_data("https://pl-flash-data.s3.amazonaws.com/wmt_en_ro.zip", "data/") diff --git a/tests/examples/test_scripts.py b/tests/examples/test_scripts.py index ba5dd7d82b..2fc4ee18f3 100644 --- a/tests/examples/test_scripts.py +++ b/tests/examples/test_scripts.py @@ -62,7 +62,7 @@ def run_test(filepath): ("finetuning", "tabular_classification.py"), # ("finetuning", "video_classification.py"), # ("finetuning", "text_classification.py"), # TODO: takes too long - # ("finetuning", "translation.py"), # TODO: takes too long. + ("finetuning", "translation.py"), ("predict", "image_classification.py"), ("predict", "image_classification_multi_label.py"), ("predict", "tabular_classification.py"), @@ -70,7 +70,7 @@ def run_test(filepath): ("predict", "image_embedder.py"), ("predict", "video_classification.py"), # ("predict", "summarization.py"), # TODO: takes too long - # ("predict", "translate.py"), # TODO: takes too long + ("predict", "translation.py"), ] ) def test_example(tmpdir, folder, file): From 4679cb51dbe176d05acc985426fafc2521be2868 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Fri, 7 May 2021 10:41:05 +0100 Subject: [PATCH 62/78] Fixes --- flash/text/seq2seq/translation/data.py | 2 +- flash/text/seq2seq/translation/model.py | 2 +- flash_examples/custom_task.py | 2 +- flash_examples/finetuning/video_classification.py | 2 +- tests/data/test_callbacks.py | 4 +--- 5 files changed, 5 insertions(+), 7 deletions(-) diff --git a/flash/text/seq2seq/translation/data.py b/flash/text/seq2seq/translation/data.py index 1475227086..057ce41869 100644 --- a/flash/text/seq2seq/translation/data.py +++ b/flash/text/seq2seq/translation/data.py @@ -24,7 +24,7 @@ def __init__( val_transform: Optional[Dict[str, Callable]] = None, test_transform: Optional[Dict[str, Callable]] = None, predict_transform: Optional[Dict[str, Callable]] = None, - backbone: str = "Helsinki-NLP/opus-mt-en-ro", + backbone: str = "t5-small", max_source_length: int = 128, max_target_length: int = 128, padding: Union[str, bool] = 'max_length' diff --git a/flash/text/seq2seq/translation/model.py b/flash/text/seq2seq/translation/model.py index 1ae64d3e11..9eba02d753 100644 --- a/flash/text/seq2seq/translation/model.py +++ b/flash/text/seq2seq/translation/model.py @@ -37,7 +37,7 @@ class TranslationTask(Seq2SeqTask): def __init__( self, - backbone: str = "Helsinki-NLP/opus-mt-en-ro", + backbone: str = "t5-small", loss_fn: Optional[Union[Callable, Mapping, Sequence]] = None, optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam, metrics: Union[pl.metrics.Metric, Mapping, Sequence, None] = None, diff --git a/flash_examples/custom_task.py b/flash_examples/custom_task.py index ff5c395d82..8fc9c3de88 100644 --- a/flash_examples/custom_task.py +++ b/flash_examples/custom_task.py @@ -10,7 +10,7 @@ import flash from flash.data.auto_dataset import AutoDataset from flash.data.data_source import DataSource -from flash.data.process import Postprocess, Preprocess +from flash.data.process import Preprocess seed_everything(42) diff --git a/flash_examples/finetuning/video_classification.py b/flash_examples/finetuning/video_classification.py index 78105f1c4e..c9ede4f043 100644 --- a/flash_examples/finetuning/video_classification.py +++ 
b/flash_examples/finetuning/video_classification.py @@ -96,7 +96,7 @@ def make_transform( model.serializer = Labels() # 6. Finetune the model - trainer = flash.Trainer(max_epochs=1, limit_train_batches=1, limit_val_batches=1) + trainer = flash.Trainer(max_epochs=3) trainer.finetune(model, datamodule=datamodule, strategy=NoFreeze()) trainer.save_checkpoint("video_classification.pt") diff --git a/tests/data/test_callbacks.py b/tests/data/test_callbacks.py index 47a118c518..58eabf2dc6 100644 --- a/tests/data/test_callbacks.py +++ b/tests/data/test_callbacks.py @@ -40,7 +40,7 @@ def test_base_data_fetcher(tmpdir): class CheckData(BaseDataFetcher): def check(self): - assert self.batches["val"]["load_sample"] == [0, 1, 2, 3, 4] + # assert self.batches["val"]["load_sample"] == [0, 1, 2, 3, 4] # TODO: This fails only on some CI assert self.batches["val"]["pre_tensor_transform"] == [0, 1, 2, 3, 4] assert self.batches["val"]["to_tensor_transform"] == [0, 1, 2, 3, 4] assert self.batches["val"]["post_tensor_transform"] == [0, 1, 2, 3, 4] @@ -77,8 +77,6 @@ def from_inputs(cls, train_data: Any, val_data: Any, test_data: Any, predict_dat with data_fetcher.enable(): _ = next(iter(dm.val_dataloader())) - # TODO: the method below fails because the data fetcher internally doesn't seem to cache - # properly the batches at each stage. data_fetcher.check() data_fetcher.reset() assert data_fetcher.batches == {'train': {}, 'test': {}, 'val': {}, 'predict': {}} From 05a1e98b173417da64c1a94e207efcd93deb4bb7 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Fri, 7 May 2021 10:56:04 +0100 Subject: [PATCH 63/78] Fixes --- flash/data/callback.py | 3 ++- tests/data/test_callbacks.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/flash/data/callback.py b/flash/data/callback.py index 1221046a31..8609868f7c 100644 --- a/flash/data/callback.py +++ b/flash/data/callback.py @@ -156,7 +156,8 @@ def __init__(self, enabled: bool = False): def _store(self, data: Any, fn_name: str, running_stage: RunningStage) -> None: if self.enabled: store = self.batches[_STAGES_PREFIX[running_stage]] - store.setdefault(fn_name, []) + if fn_name not in store: + store[fn_name] = [] store[fn_name].append(data) def on_load_sample(self, sample: Any, running_stage: RunningStage) -> None: diff --git a/tests/data/test_callbacks.py b/tests/data/test_callbacks.py index 58eabf2dc6..ad07288ac0 100644 --- a/tests/data/test_callbacks.py +++ b/tests/data/test_callbacks.py @@ -40,7 +40,7 @@ def test_base_data_fetcher(tmpdir): class CheckData(BaseDataFetcher): def check(self): - # assert self.batches["val"]["load_sample"] == [0, 1, 2, 3, 4] # TODO: This fails only on some CI + assert self.batches["val"]["load_sample"] == [0, 1, 2, 3, 4] assert self.batches["val"]["pre_tensor_transform"] == [0, 1, 2, 3, 4] assert self.batches["val"]["to_tensor_transform"] == [0, 1, 2, 3, 4] assert self.batches["val"]["post_tensor_transform"] == [0, 1, 2, 3, 4] From 46b6a4f7eef7ff484f108f8a7034327c4523e12b Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Fri, 7 May 2021 11:13:44 +0100 Subject: [PATCH 64/78] Bump huggingface minimal --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 8aaa1ec97d..39329407f9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,7 +5,7 @@ pytorch-lightning>=1.3.0rc1 lightning-bolts>=0.3.3 PyYAML>=5.1 Pillow>=7.2 -transformers>=4.0 +transformers>=4.1 pytorch-tabnet==3.1 datasets>=1.2, <1.3 pandas>=1.1 From 5b2013ec481eec652ea4cf893a1a4b44578b17bb 
Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Fri, 7 May 2021 11:19:46 +0100 Subject: [PATCH 65/78] debugging --- requirements.txt | 2 +- tests/data/test_callbacks.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 39329407f9..6a257aa5a3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,7 +5,7 @@ pytorch-lightning>=1.3.0rc1 lightning-bolts>=0.3.3 PyYAML>=5.1 Pillow>=7.2 -transformers>=4.1 +transformers>=4.2 pytorch-tabnet==3.1 datasets>=1.2, <1.3 pandas>=1.1 diff --git a/tests/data/test_callbacks.py b/tests/data/test_callbacks.py index ad07288ac0..0c2b103e54 100644 --- a/tests/data/test_callbacks.py +++ b/tests/data/test_callbacks.py @@ -75,6 +75,7 @@ def from_inputs(cls, train_data: Any, val_data: Any, test_data: Any, predict_dat data_fetcher: CheckData = dm.data_fetcher with data_fetcher.enable(): + assert data_fetcher.enabled _ = next(iter(dm.val_dataloader())) data_fetcher.check() From 75f3469a2b7065aa3e879c977d52433acf8827b3 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Fri, 7 May 2021 11:29:43 +0100 Subject: [PATCH 66/78] debugging --- requirements.txt | 2 +- tests/data/test_callbacks.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 6a257aa5a3..bce2efd675 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,7 +5,7 @@ pytorch-lightning>=1.3.0rc1 lightning-bolts>=0.3.3 PyYAML>=5.1 Pillow>=7.2 -transformers>=4.2 +transformers>=4.5 pytorch-tabnet==3.1 datasets>=1.2, <1.3 pandas>=1.1 diff --git a/tests/data/test_callbacks.py b/tests/data/test_callbacks.py index 0c2b103e54..8930ac853b 100644 --- a/tests/data/test_callbacks.py +++ b/tests/data/test_callbacks.py @@ -40,6 +40,7 @@ def test_base_data_fetcher(tmpdir): class CheckData(BaseDataFetcher): def check(self): + print(self.batches["val"]) assert self.batches["val"]["load_sample"] == [0, 1, 2, 3, 4] assert self.batches["val"]["pre_tensor_transform"] == [0, 1, 2, 3, 4] assert self.batches["val"]["to_tensor_transform"] == [0, 1, 2, 3, 4] From 950b13f03bafcf871510bfc71e33246303b602a7 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Fri, 7 May 2021 11:41:57 +0100 Subject: [PATCH 67/78] Fixes --- tests/data/test_callbacks.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/data/test_callbacks.py b/tests/data/test_callbacks.py index 8930ac853b..d8fb4c5391 100644 --- a/tests/data/test_callbacks.py +++ b/tests/data/test_callbacks.py @@ -40,7 +40,6 @@ def test_base_data_fetcher(tmpdir): class CheckData(BaseDataFetcher): def check(self): - print(self.batches["val"]) assert self.batches["val"]["load_sample"] == [0, 1, 2, 3, 4] assert self.batches["val"]["pre_tensor_transform"] == [0, 1, 2, 3, 4] assert self.batches["val"]["to_tensor_transform"] == [0, 1, 2, 3, 4] @@ -75,9 +74,12 @@ def from_inputs(cls, train_data: Any, val_data: Any, test_data: Any, predict_dat dm = CustomDataModule.from_inputs(range(5), range(5), range(5), range(5)) data_fetcher: CheckData = dm.data_fetcher + if not hasattr(dm, "_val_iter"): + dm._reset_iterator("val") + with data_fetcher.enable(): assert data_fetcher.enabled - _ = next(iter(dm.val_dataloader())) + _ = next(dm._val_iter()) data_fetcher.check() data_fetcher.reset() From f47208ce6cda1c2ccf95f9e45e9453fe94ad4a8a Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Fri, 7 May 2021 11:42:44 +0100 Subject: [PATCH 68/78] Fixes --- tests/data/test_callbacks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/tests/data/test_callbacks.py b/tests/data/test_callbacks.py index d8fb4c5391..f4748a5149 100644 --- a/tests/data/test_callbacks.py +++ b/tests/data/test_callbacks.py @@ -79,7 +79,7 @@ def from_inputs(cls, train_data: Any, val_data: Any, test_data: Any, predict_dat with data_fetcher.enable(): assert data_fetcher.enabled - _ = next(dm._val_iter()) + _ = next(dm._val_iter) data_fetcher.check() data_fetcher.reset() From db0c991185a7ede325592707aeb1f4f24138af8b Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Fri, 7 May 2021 14:51:16 +0100 Subject: [PATCH 69/78] Respond to comments --- .../reference/tabular_classification.rst | 2 +- flash/core/model.py | 3 +- flash/data/auto_dataset.py | 4 +- flash/data/data_module.py | 27 ++++--- flash/data/data_source.py | 9 +-- flash/data/utils.py | 5 +- flash/tabular/classification/data/data.py | 78 ++++++++++--------- tests/tabular/data/test_data.py | 58 +++++++------- tests/tabular/test_data_model_integration.py | 14 ++-- 9 files changed, 102 insertions(+), 98 deletions(-) diff --git a/docs/source/reference/tabular_classification.rst b/docs/source/reference/tabular_classification.rst index e54356c751..9812bab90b 100644 --- a/docs/source/reference/tabular_classification.rst +++ b/docs/source/reference/tabular_classification.rst @@ -165,4 +165,4 @@ TabularData .. automethod:: flash.tabular.TabularData.from_csv -.. automethod:: flash.tabular.TabularData.from_df +.. automethod:: flash.tabular.TabularData.from_data_frame diff --git a/flash/core/model.py b/flash/core/model.py index 0d88526cb6..0b9657ec1a 100644 --- a/flash/core/model.py +++ b/flash/core/model.py @@ -158,7 +158,7 @@ def test_step(self, batch: Any, batch_idx: int) -> None: def predict( self, x: Any, - data_source: Union[str] = "default", + data_source: str = "default", data_pipeline: Optional[DataPipeline] = None, ) -> Any: """ @@ -177,7 +177,6 @@ def predict( data_pipeline = self.build_data_pipeline(data_source, data_pipeline) x = [x for x in data_pipeline._data_source.generate_dataset(x, running_stage)] - assert len(x) > 0, "List of inputs shouldn't be empty." x = data_pipeline.worker_preprocessor(running_stage)(x) # switch to self.device when #7188 merge in Lightning x = self.transfer_batch_to_device(x, next(self.parameters()).device) diff --git a/flash/data/auto_dataset.py b/flash/data/auto_dataset.py index ad4131a78e..1755ea7c7b 100644 --- a/flash/data/auto_dataset.py +++ b/flash/data/auto_dataset.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from inspect import signature -from typing import Any, Generic, Iterable, Sequence, TYPE_CHECKING, TypeVar +from typing import Any, Callable, Generic, Iterable, Sequence, TYPE_CHECKING, TypeVar from pytorch_lightning.trainer.states import RunningStage from torch.utils.data import Dataset, IterableDataset @@ -62,7 +62,7 @@ def running_stage(self, running_stage: RunningStage) -> None: self._load_sample_context = CurrentRunningStageFuncContext(self.running_stage, "load_sample", self.data_source) - self.load_sample = getattr( + self.load_sample: Callable = getattr( self.data_source, DataPipeline._resolve_function_hierarchy( 'load_sample', diff --git a/flash/data/data_module.py b/flash/data/data_module.py index 0e920e7be9..a47c4a148c 100644 --- a/flash/data/data_module.py +++ b/flash/data/data_module.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import os -import pathlib import platform from typing import Any, Callable, Collection, Dict, Iterable, List, Optional, Sequence, Tuple, Union @@ -63,7 +62,7 @@ def __init__( data_fetcher: Optional[BaseDataFetcher] = None, val_split: Optional[float] = None, batch_size: int = 1, - num_workers: Optional[int] = 0, + num_workers: Optional[int] = None, ) -> None: super().__init__() @@ -385,10 +384,10 @@ def from_data_source( @classmethod def from_folders( cls, - train_folder: Optional[Union[str, pathlib.Path]] = None, - val_folder: Optional[Union[str, pathlib.Path]] = None, - test_folder: Optional[Union[str, pathlib.Path]] = None, - predict_folder: Union[str, pathlib.Path] = None, + train_folder: Optional[str] = None, + val_folder: Optional[str] = None, + test_folder: Optional[str] = None, + predict_folder: Optional[str] = None, train_transform: Optional[Union[str, Dict]] = 'default', val_transform: Optional[Union[str, Dict]] = 'default', test_transform: Optional[Union[str, Dict]] = 'default', @@ -421,13 +420,13 @@ def from_folders( @classmethod def from_files( cls, - train_files: Optional[Sequence[Union[str, pathlib.Path]]] = None, + train_files: Optional[Sequence[str]] = None, train_targets: Optional[Sequence[Any]] = None, - val_files: Optional[Sequence[Union[str, pathlib.Path]]] = None, + val_files: Optional[Sequence[str]] = None, val_targets: Optional[Sequence[Any]] = None, - test_files: Optional[Sequence[Union[str, pathlib.Path]]] = None, + test_files: Optional[Sequence[str]] = None, test_targets: Optional[Sequence[Any]] = None, - predict_files: Optional[Sequence[Union[str, pathlib.Path]]] = None, + predict_files: Optional[Sequence[str]] = None, train_transform: Optional[Union[str, Dict]] = 'default', val_transform: Optional[Union[str, Dict]] = 'default', test_transform: Optional[Union[str, Dict]] = 'default', @@ -538,8 +537,8 @@ def from_numpy( @classmethod def from_json( cls, - input_fields: Union[str, List[str]], - target_fields: Optional[Union[str, List[str]]] = None, + input_fields: Union[str, Sequence[str]], + target_fields: Optional[Union[str, Sequence[str]]] = None, train_file: Optional[str] = None, val_file: Optional[str] = None, test_file: Optional[str] = None, @@ -568,8 +567,8 @@ def from_json( @classmethod def from_csv( cls, - input_fields: Union[str, List[str]], - target_fields: Optional[Union[str, List[str]]] = None, + input_fields: Union[str, Sequence[str]], + target_fields: Optional[Union[str, Sequence[str]]] = None, train_file: Optional[str] = None, val_file: Optional[str] = None, test_file: Optional[str] = None, diff --git a/flash/data/data_source.py b/flash/data/data_source.py index b385b2db0c..40402f7f49 100644 --- a/flash/data/data_source.py +++ b/flash/data/data_source.py @@ -14,7 +14,7 @@ import os from dataclasses import dataclass from inspect import signature -from typing import Any, Dict, Generic, Iterable, List, Mapping, Optional, Sequence, Tuple, TypeVar, Union +from typing import Any, Callable, Dict, Generic, Iterable, List, Mapping, Optional, Sequence, Tuple, TypeVar, Union import numpy as np import torch @@ -97,9 +97,7 @@ def generate_dataset( running_stage: RunningStage, ) -> Optional[Union[AutoDataset, IterableAutoDataset]]: is_none = data is None - # TODO: we should parse better the possible data types here. - # Are `pata_paths` considered as Sequence ? for now it pass - # the statement found in below. 
+ if isinstance(data, Sequence): is_none = data[0] is None @@ -108,7 +106,7 @@ def generate_dataset( mock_dataset = MockDataset() with CurrentRunningStageFuncContext(running_stage, "load_data", self): - load_data = getattr( + load_data: Callable = getattr( self, DataPipeline._resolve_function_hierarchy( "load_data", self, @@ -215,6 +213,7 @@ def isdir(data: Union[str, Tuple[List[str], List[Any]]]) -> bool: try: return os.path.isdir(data) except TypeError: + # data is not path-like (e.g. it may be a list of paths) return False def load_data(self, diff --git a/flash/data/utils.py b/flash/data/utils.py index bf69611f2f..9a329beb78 100644 --- a/flash/data/utils.py +++ b/flash/data/utils.py @@ -33,12 +33,11 @@ _STAGES_PREFIX_VALUES = {"train", "test", "val", "predict"} _DATASOURCE_FUNCS: Set[str] = { - 'load_data', - 'load_sample', + "load_data", + "load_sample", } _PREPROCESS_FUNCS: Set[str] = { - # "load_sample", # TODO: This should still be a callback hook "pre_tensor_transform", "to_tensor_transform", "post_tensor_transform", diff --git a/flash/tabular/classification/data/data.py b/flash/tabular/classification/data/data.py index 4e3029e457..2969405275 100644 --- a/flash/tabular/classification/data/data.py +++ b/flash/tabular/classification/data/data.py @@ -140,7 +140,7 @@ def __init__( DefaultDataSources.CSV: TabularCSVDataSource( cat_cols, num_cols, target_col, mean, std, codes, target_codes, classes, is_regression ), - "df": TabularDataFrameDataSource( + "data_frame": TabularDataFrameDataSource( cat_cols, num_cols, target_col, mean, std, codes, target_codes, classes, is_regression ), }, @@ -211,51 +211,53 @@ def _sanetize_cols(cat_cols: Optional[List], num_cols: Optional[List]): @classmethod def compute_state( cls, - train_df: DataFrame, - val_df: Optional[DataFrame], - test_df: Optional[DataFrame], - predict_df: Optional[DataFrame], + train_data_frame: DataFrame, + val_data_frame: Optional[DataFrame], + test_data_frame: Optional[DataFrame], + predict_data_frame: Optional[DataFrame], target_col: str, num_cols: List[str], cat_cols: List[str], ) -> Tuple[float, float, List[str], Dict[str, Any], Dict[str, Any]]: - if train_df is None: - raise MisconfigurationException("train_df is required to instantiate the TabularDataFrameDataSource") + if train_data_frame is None: + raise MisconfigurationException( + "train_data_frame is required to instantiate the TabularDataFrameDataSource" + ) - dfs = [train_df] + data_frames = [train_data_frame] - if val_df is not None: - dfs += [val_df] + if val_data_frame is not None: + data_frames += [val_data_frame] - if test_df is not None: - dfs += [test_df] + if test_data_frame is not None: + data_frames += [test_data_frame] - if predict_df is not None: - dfs += [predict_df] + if predict_data_frame is not None: + data_frames += [predict_data_frame] - mean, std = _compute_normalization(dfs[0], num_cols) - classes = list(dfs[0][target_col].unique()) + mean, std = _compute_normalization(data_frames[0], num_cols) + classes = list(data_frames[0][target_col].unique()) - if dfs[0][target_col].dtype == object: + if data_frames[0][target_col].dtype == object: # if the target_col is a category, not an int - target_codes = _generate_codes(dfs, [target_col]) + target_codes = _generate_codes(data_frames, [target_col]) else: target_codes = None - codes = _generate_codes(dfs, cat_cols) + codes = _generate_codes(data_frames, cat_cols) return mean, std, classes, codes, target_codes @classmethod - def from_df( + def from_data_frame( cls, categorical_cols: List, numerical_cols: 
List, target_col: str, - train_df: DataFrame, - val_df: Optional[DataFrame] = None, - test_df: Optional[DataFrame] = None, - predict_df: Optional[DataFrame] = None, + train_data_frame: DataFrame, + val_data_frame: Optional[DataFrame] = None, + test_data_frame: Optional[DataFrame] = None, + predict_data_frame: Optional[DataFrame] = None, is_regression: bool = False, preprocess: Optional[Preprocess] = None, val_split: float = None, @@ -288,15 +290,21 @@ def from_df( categorical_cols, numerical_cols = cls._sanetize_cols(categorical_cols, numerical_cols) mean, std, classes, codes, target_codes = cls.compute_state( - train_df, val_df, test_df, predict_df, target_col, numerical_cols, categorical_cols + train_data_frame, + val_data_frame, + test_data_frame, + predict_data_frame, + target_col, + numerical_cols, + categorical_cols, ) return cls.from_data_source( - data_source="df", - train_data=train_df, - val_data=val_df, - test_data=test_df, - predict_data=predict_df, + data_source="data_frame", + train_data=train_data_frame, + val_data=val_data_frame, + test_data=test_data_frame, + predict_data=predict_data_frame, preprocess=preprocess, val_split=val_split, batch_size=batch_size, @@ -328,14 +336,14 @@ def from_csv( batch_size: int = 4, num_workers: Optional[int] = None, ) -> 'DataModule': - return cls.from_df( + return cls.from_data_frame( categorical_fields, numerical_fields, target_field, - train_df=pd.read_csv(train_file) if train_file is not None else None, - val_df=pd.read_csv(val_file) if val_file is not None else None, - test_df=pd.read_csv(test_file) if test_file is not None else None, - predict_df=pd.read_csv(predict_file) if predict_file is not None else None, + train_data_frame=pd.read_csv(train_file) if train_file is not None else None, + val_data_frame=pd.read_csv(val_file) if val_file is not None else None, + test_data_frame=pd.read_csv(test_file) if test_file is not None else None, + predict_data_frame=pd.read_csv(predict_file) if predict_file is not None else None, is_regression=is_regression, preprocess=preprocess, val_split=val_split, diff --git a/tests/tabular/data/test_data.py b/tests/tabular/data/test_data.py index 99f9432c42..1a0d1e1574 100644 --- a/tests/tabular/data/test_data.py +++ b/tests/tabular/data/test_data.py @@ -83,16 +83,16 @@ def test_emb_sizes(): def test_tabular_data(tmpdir): - train_df = TEST_DF_1.copy() - val_df = TEST_DF_2.copy() - test_df = TEST_DF_2.copy() - dm = TabularData.from_df( + train_data_frame = TEST_DF_1.copy() + val_data_frame = TEST_DF_2.copy() + test_data_frame = TEST_DF_2.copy() + dm = TabularData.from_data_frame( categorical_cols=["category"], numerical_cols=["scalar_b", "scalar_b"], target_col="label", - train_df=train_df, - val_df=val_df, - test_df=test_df, + train_data_frame=train_data_frame, + val_data_frame=val_data_frame, + test_data_frame=test_data_frame, num_workers=0, batch_size=1, ) @@ -106,20 +106,20 @@ def test_tabular_data(tmpdir): def test_categorical_target(tmpdir): - train_df = TEST_DF_1.copy() - val_df = TEST_DF_2.copy() - test_df = TEST_DF_2.copy() - for df in [train_df, val_df, test_df]: + train_data_frame = TEST_DF_1.copy() + val_data_frame = TEST_DF_2.copy() + test_data_frame = TEST_DF_2.copy() + for df in [train_data_frame, val_data_frame, test_data_frame]: # change int label to string df["label"] = df["label"].astype(str) - dm = TabularData.from_df( + dm = TabularData.from_data_frame( categorical_cols=["category"], numerical_cols=["scalar_b", "scalar_b"], target_col="label", - train_df=train_df, - val_df=val_df, - 
test_df=test_df, + train_data_frame=train_data_frame, + val_data_frame=val_data_frame, + test_data_frame=test_data_frame, num_workers=0, batch_size=1, ) @@ -132,17 +132,17 @@ def test_categorical_target(tmpdir): assert target.shape == (1, ) -def test_from_df(tmpdir): - train_df = TEST_DF_1.copy() - val_df = TEST_DF_2.copy() - test_df = TEST_DF_2.copy() - dm = TabularData.from_df( +def test_from_data_frame(tmpdir): + train_data_frame = TEST_DF_1.copy() + val_data_frame = TEST_DF_2.copy() + test_data_frame = TEST_DF_2.copy() + dm = TabularData.from_data_frame( categorical_cols=["category"], numerical_cols=["scalar_b", "scalar_b"], target_col="label", - train_df=train_df, - val_df=val_df, - test_df=test_df, + train_data_frame=train_data_frame, + val_data_frame=val_data_frame, + test_data_frame=test_data_frame, num_workers=0, batch_size=1 ) @@ -166,9 +166,9 @@ def test_from_csv(tmpdir): categorical_fields=["category"], numerical_fields=["scalar_b", "scalar_b"], target_field="label", - train_file=train_csv, - val_file=val_csv, - test_file=test_csv, + train_file=str(train_csv), + val_file=str(val_csv), + test_file=str(test_csv), num_workers=0, batch_size=1 ) @@ -182,13 +182,13 @@ def test_from_csv(tmpdir): def test_empty_inputs(): - train_df = TEST_DF_1.copy() + train_data_frame = TEST_DF_1.copy() with pytest.raises(RuntimeError): - TabularData.from_df( + TabularData.from_data_frame( numerical_cols=None, categorical_cols=None, target_col="label", - train_df=train_df, + train_data_frame=train_data_frame, num_workers=0, batch_size=1, ) diff --git a/tests/tabular/test_data_model_integration.py b/tests/tabular/test_data_model_integration.py index 0f4cb26b62..6dcec9b6a8 100644 --- a/tests/tabular/test_data_model_integration.py +++ b/tests/tabular/test_data_model_integration.py @@ -28,16 +28,16 @@ def test_classification(tmpdir): - train_df = TEST_DF_1.copy() - val_df = TEST_DF_1.copy() - test_df = TEST_DF_1.copy() - data = TabularData.from_df( + train_data_frame = TEST_DF_1.copy() + val_data_frame = TEST_DF_1.copy() + test_data_frame = TEST_DF_1.copy() + data = TabularData.from_data_frame( categorical_cols=["category"], numerical_cols=["scalar_a", "scalar_b"], target_col="label", - train_df=train_df, - val_df=val_df, - test_df=test_df, + train_data_frame=train_data_frame, + val_data_frame=val_data_frame, + test_data_frame=test_data_frame, num_workers=0, batch_size=2, ) From db1cdf1f0c508c79289e8561606d6dae910bbf26 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Fri, 7 May 2021 15:36:34 +0100 Subject: [PATCH 70/78] feedback --- flash/data/auto_dataset.py | 17 +++++++---- flash/data/process.py | 46 ++++++++++------------------- flash/vision/classification/data.py | 6 ++++ flash/vision/detection/data.py | 6 ++++ 4 files changed, 40 insertions(+), 35 deletions(-) diff --git a/flash/data/auto_dataset.py b/flash/data/auto_dataset.py index 1755ea7c7b..55a6352e72 100644 --- a/flash/data/auto_dataset.py +++ b/flash/data/auto_dataset.py @@ -29,10 +29,17 @@ class BaseAutoDataset(Generic[DATA_TYPE]): DATASET_KEY = "dataset" - """ - This class is used to encapsulate a Preprocess Object ``load_data`` and ``load_sample`` functions. - ``load_data`` will be called within the ``__init__`` function of the AutoDataset if ``running_stage`` - is provided and ``load_sample`` within ``__getitem__``. + """This class is used to encapsulate a Preprocess Object ``load_data`` and ``load_sample`` functions. 
``load_data`` + will be called within the ``__init__`` function of the AutoDataset if ``running_stage`` is provided and + ``load_sample`` within ``__getitem__``. + + Args: + + data: The output of a call to :meth:`~flash.data.data_source.load_data`. + + data_source: The :class:`~flash.data.data_source.DataSource` which has the ``load_sample`` method. + + running_stage: The current running stage. """ def __init__( @@ -56,7 +63,7 @@ def running_stage(self) -> RunningStage: @running_stage.setter def running_stage(self, running_stage: RunningStage) -> None: from flash.data.data_pipeline import DataPipeline # noqa F811 - from flash.data.data_source import DataSource # Hack to avoid circular import TODO: something better than this + from flash.data.data_source import DataSource # noqa F811 # TODO: something better than this self._running_stage = running_stage diff --git a/flash/data/process.py b/flash/data/process.py index e86095b12e..2c6e04a2f4 100644 --- a/flash/data/process.py +++ b/flash/data/process.py @@ -216,9 +216,10 @@ def __init__( super().__init__() # resolve the default transforms - train_transform, val_transform, test_transform, predict_transform = self._resolve_transforms( - train_transform, val_transform, test_transform, predict_transform - ) + train_transform = train_transform or self.default_train_transforms() + val_transform = val_transform or self.default_val_transforms() + test_transform = test_transform or self.default_test_transforms() + predict_transform = predict_transform or self.default_predict_transforms() # used to keep track of provided transforms self._train_collate_in_worker_from_transform: Optional[bool] = None @@ -241,6 +242,18 @@ def __init__( self._default_data_source = default_data_source self._callbacks: List[FlashCallback] = [] + def default_train_transforms(self) -> Optional[Dict[str, Callable]]: + pass + + def default_val_transforms(self) -> Optional[Dict[str, Callable]]: + pass + + def default_test_transforms(self) -> Optional[Dict[str, Callable]]: + pass + + def default_predict_transforms(self) -> Optional[Dict[str, Callable]]: + pass + def _save_to_state_dict(self, destination, prefix, keep_vars): preprocess_state_dict = self.get_state_dict() if not isinstance(preprocess_state_dict, Dict): @@ -253,33 +266,6 @@ def _save_to_state_dict(self, destination, prefix, keep_vars): self._ddp_params_and_buffers_to_ignore = ['preprocess.state_dict'] return super()._save_to_state_dict(destination, prefix, keep_vars) - def default_train_transforms(self) -> Optional[Dict[str, Callable]]: - return None - - def default_val_transforms(self) -> Optional[Dict[str, Callable]]: - return None - - def _resolve_transforms( - self, - train_transform: Optional[Union[str, Dict]] = 'default', - val_transform: Optional[Union[str, Dict]] = 'default', - test_transform: Optional[Union[str, Dict]] = 'default', - predict_transform: Optional[Union[str, Dict]] = 'default', - ): - if not train_transform or train_transform == 'default': - train_transform = self.default_train_transforms() - - if not val_transform or val_transform == 'default': - val_transform = self.default_val_transforms() - - if not test_transform or test_transform == 'default': - test_transform = self.default_val_transforms() - - if not predict_transform or predict_transform == 'default': - predict_transform = self.default_val_transforms() - - return train_transform, val_transform, test_transform, predict_transform - def _check_transforms(self, transform: Optional[Dict[str, Callable]], stage: RunningStage) -> 
Optional[Dict[str, Callable]]: if transform is None: diff --git a/flash/vision/classification/data.py b/flash/vision/classification/data.py index 230c0ad417..6b061f1947 100644 --- a/flash/vision/classification/data.py +++ b/flash/vision/classification/data.py @@ -87,6 +87,12 @@ def default_train_transforms(self) -> Optional[Dict[str, Callable]]: def default_val_transforms(self) -> Optional[Dict[str, Callable]]: return default_val_transforms(self.image_size) + def default_test_transforms(self) -> Optional[Dict[str, Callable]]: + return default_val_transforms(self.image_size) + + def default_predict_transforms(self) -> Optional[Dict[str, Callable]]: + return default_val_transforms(self.image_size) + class ImageClassificationData(DataModule): """Data module for image classification tasks.""" diff --git a/flash/vision/detection/data.py b/flash/vision/detection/data.py index 590e7f9a83..676a602f67 100644 --- a/flash/vision/detection/data.py +++ b/flash/vision/detection/data.py @@ -131,6 +131,12 @@ def default_train_transforms(self) -> Optional[Dict[str, Callable]]: def default_val_transforms(self) -> Optional[Dict[str, Callable]]: return default_transforms() + def default_test_transforms(self) -> Optional[Dict[str, Callable]]: + return default_transforms() + + def default_predict_transforms(self) -> Optional[Dict[str, Callable]]: + return default_transforms() + class ObjectDetectionData(DataModule): From 88cbc6509973389bb3a2ad2d61f3f4cea1f4cec0 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Fri, 7 May 2021 16:08:44 +0100 Subject: [PATCH 71/78] Updates --- flash/core/model.py | 8 ++--- flash/data/data_module.py | 56 ++++++++++++++++++++----------- flash/data/data_pipeline.py | 6 ++-- flash/data/process.py | 4 +-- flash/text/classification/data.py | 29 +++++++--------- flash/text/seq2seq/core/data.py | 24 ++++++------- flash/text/seq2seq/core/model.py | 2 +- 7 files changed, 70 insertions(+), 59 deletions(-) diff --git a/flash/core/model.py b/flash/core/model.py index 0b9657ec1a..4115711672 100644 --- a/flash/core/model.py +++ b/flash/core/model.py @@ -158,7 +158,7 @@ def test_step(self, batch: Any, batch_idx: int) -> None: def predict( self, x: Any, - data_source: str = "default", + data_source: Optional[str] = None, data_pipeline: Optional[DataPipeline] = None, ) -> Any: """ @@ -176,7 +176,7 @@ def predict( data_pipeline = self.build_data_pipeline(data_source, data_pipeline) - x = [x for x in data_pipeline._data_source.generate_dataset(x, running_stage)] + x = [x for x in data_pipeline.data_source.generate_dataset(x, running_stage)] x = data_pipeline.worker_preprocessor(running_stage)(x) # switch to self.device when #7188 merge in Lightning x = self.transfer_batch_to_device(x, next(self.parameters()).device) @@ -282,7 +282,7 @@ def build_data_pipeline( # Datamodule if self.datamodule is not None and getattr(self.datamodule, 'data_pipeline', None) is not None: - old_data_source = getattr(self.datamodule.data_pipeline, '_data_source', None) + old_data_source = getattr(self.datamodule.data_pipeline, 'data_source', None) preprocess = getattr(self.datamodule.data_pipeline, '_preprocess_pipeline', None) postprocess = getattr(self.datamodule.data_pipeline, '_postprocess_pipeline', None) serializer = getattr(self.datamodule.data_pipeline, '_serializer', None) @@ -290,7 +290,7 @@ def build_data_pipeline( elif self.trainer is not None and hasattr( self.trainer, 'datamodule' ) and getattr(self.trainer.datamodule, 'data_pipeline', None) is not None: - old_data_source = 
getattr(self.trainer.datamodule.data_pipeline, '_data_source', None) + old_data_source = getattr(self.trainer.datamodule.data_pipeline, 'data_source', None) preprocess = getattr(self.trainer.datamodule.data_pipeline, '_preprocess_pipeline', None) postprocess = getattr(self.trainer.datamodule.data_pipeline, '_postprocess_pipeline', None) serializer = getattr(self.trainer.datamodule.data_pipeline, '_serializer', None) diff --git a/flash/data/data_module.py b/flash/data/data_module.py index a47c4a148c..47129ce811 100644 --- a/flash/data/data_module.py +++ b/flash/data/data_module.py @@ -340,10 +340,10 @@ def from_data_source( val_data: Any = None, test_data: Any = None, predict_data: Any = None, - train_transform: Optional[Union[str, Dict]] = 'default', - val_transform: Optional[Union[str, Dict]] = 'default', - test_transform: Optional[Union[str, Dict]] = 'default', - predict_transform: Optional[Union[str, Dict]] = 'default', + train_transform: Optional[Dict] = None, + val_transform: Optional[Dict] = None, + test_transform: Optional[Dict] = None, + predict_transform: Optional[Dict] = None, data_fetcher: BaseDataFetcher = None, preprocess: Optional[Preprocess] = None, val_split: Optional[float] = None, @@ -388,10 +388,10 @@ def from_folders( val_folder: Optional[str] = None, test_folder: Optional[str] = None, predict_folder: Optional[str] = None, - train_transform: Optional[Union[str, Dict]] = 'default', - val_transform: Optional[Union[str, Dict]] = 'default', - test_transform: Optional[Union[str, Dict]] = 'default', - predict_transform: Optional[Union[str, Dict]] = 'default', + train_transform: Optional[Dict] = None, + val_transform: Optional[Dict] = None, + test_transform: Optional[Dict] = None, + predict_transform: Optional[Dict] = None, data_fetcher: BaseDataFetcher = None, preprocess: Optional[Preprocess] = None, val_split: Optional[float] = None, @@ -427,10 +427,10 @@ def from_files( test_files: Optional[Sequence[str]] = None, test_targets: Optional[Sequence[Any]] = None, predict_files: Optional[Sequence[str]] = None, - train_transform: Optional[Union[str, Dict]] = 'default', - val_transform: Optional[Union[str, Dict]] = 'default', - test_transform: Optional[Union[str, Dict]] = 'default', - predict_transform: Optional[Union[str, Dict]] = 'default', + train_transform: Optional[Dict] = None, + val_transform: Optional[Dict] = None, + test_transform: Optional[Dict] = None, + predict_transform: Optional[Dict] = None, data_fetcher: BaseDataFetcher = None, preprocess: Optional[Preprocess] = None, val_split: Optional[float] = None, @@ -466,10 +466,10 @@ def from_tensors( test_data: Optional[Collection[torch.Tensor]] = None, test_targets: Optional[Sequence[Any]] = None, predict_data: Optional[Collection[torch.Tensor]] = None, - train_transform: Optional[Union[str, Dict]] = 'default', - val_transform: Optional[Union[str, Dict]] = 'default', - test_transform: Optional[Union[str, Dict]] = 'default', - predict_transform: Optional[Union[str, Dict]] = 'default', + train_transform: Optional[Dict] = None, + val_transform: Optional[Dict] = None, + test_transform: Optional[Dict] = None, + predict_transform: Optional[Dict] = None, data_fetcher: BaseDataFetcher = None, preprocess: Optional[Preprocess] = None, val_split: Optional[float] = None, @@ -505,10 +505,10 @@ def from_numpy( test_data: Optional[Collection[np.ndarray]] = None, test_targets: Optional[Sequence[Any]] = None, predict_data: Optional[Collection[np.ndarray]] = None, - train_transform: Optional[Union[str, Dict]] = 'default', - val_transform: 
Optional[Union[str, Dict]] = 'default', - test_transform: Optional[Union[str, Dict]] = 'default', - predict_transform: Optional[Union[str, Dict]] = 'default', + train_transform: Optional[Dict] = None, + val_transform: Optional[Dict] = None, + test_transform: Optional[Dict] = None, + predict_transform: Optional[Dict] = None, data_fetcher: BaseDataFetcher = None, preprocess: Optional[Preprocess] = None, val_split: Optional[float] = None, @@ -543,6 +543,10 @@ def from_json( val_file: Optional[str] = None, test_file: Optional[str] = None, predict_file: Optional[str] = None, + train_transform: Optional[Dict] = None, + val_transform: Optional[Dict] = None, + test_transform: Optional[Dict] = None, + predict_transform: Optional[Dict] = None, data_fetcher: BaseDataFetcher = None, preprocess: Optional[Preprocess] = None, val_split: Optional[float] = None, @@ -556,6 +560,10 @@ def from_json( (val_file, input_fields, target_fields), (test_file, input_fields, target_fields), (predict_file, input_fields, target_fields), + train_transform=train_transform, + val_transform=val_transform, + test_transform=test_transform, + predict_transform=predict_transform, data_fetcher=data_fetcher, preprocess=preprocess, val_split=val_split, @@ -573,6 +581,10 @@ def from_csv( val_file: Optional[str] = None, test_file: Optional[str] = None, predict_file: Optional[str] = None, + train_transform: Optional[Dict] = None, + val_transform: Optional[Dict] = None, + test_transform: Optional[Dict] = None, + predict_transform: Optional[Dict] = None, data_fetcher: BaseDataFetcher = None, preprocess: Optional[Preprocess] = None, val_split: Optional[float] = None, @@ -586,6 +598,10 @@ def from_csv( (val_file, input_fields, target_fields), (test_file, input_fields, target_fields), (predict_file, input_fields, target_fields), + train_transform=train_transform, + val_transform=val_transform, + test_transform=test_transform, + predict_transform=predict_transform, data_fetcher=data_fetcher, preprocess=preprocess, val_split=val_split, diff --git a/flash/data/data_pipeline.py b/flash/data/data_pipeline.py index 2db825affa..07ab9bab50 100644 --- a/flash/data/data_pipeline.py +++ b/flash/data/data_pipeline.py @@ -95,7 +95,7 @@ def __init__( postprocess: Optional[Postprocess] = None, serializer: Optional[Serializer] = None, ) -> None: - self._data_source = data_source + self.data_source = data_source self._preprocess_pipeline = preprocess or DefaultPreprocess() self._postprocess_pipeline = postprocess or Postprocess() @@ -110,8 +110,8 @@ def initialize(self, data_pipeline_state: Optional[DataPipelineState] = None) -> give a warning.""" data_pipeline_state = data_pipeline_state or DataPipelineState() data_pipeline_state._initialized = False - if self._data_source is not None: - self._data_source.attach_data_pipeline_state(data_pipeline_state) + if self.data_source is not None: + self.data_source.attach_data_pipeline_state(data_pipeline_state) self._preprocess_pipeline.attach_data_pipeline_state(data_pipeline_state) self._postprocess_pipeline.attach_data_pipeline_state(data_pipeline_state) self._serializer.attach_data_pipeline_state(data_pipeline_state) diff --git a/flash/data/process.py b/flash/data/process.py index 2c6e04a2f4..e279010f0e 100644 --- a/flash/data/process.py +++ b/flash/data/process.py @@ -403,8 +403,8 @@ def per_batch_transform_on_device(self, batch: Any) -> Any: """ return self.current_transform(batch) - def data_source_of_name(self, data_source_name: str) -> Optional[DATA_SOURCE_TYPE]: - if data_source_name == "default": + def 
data_source_of_name(self, data_source_name: Optional[str]) -> Optional[DATA_SOURCE_TYPE]: + if data_source_name is None: data_source_name = self._default_data_source data_sources = self._data_sources if data_source_name in data_sources: diff --git a/flash/text/classification/data.py b/flash/text/classification/data.py index 145ce65454..b7662c9838 100644 --- a/flash/text/classification/data.py +++ b/flash/text/classification/data.py @@ -21,7 +21,6 @@ from transformers.modeling_outputs import SequenceClassifierOutput from flash.data.auto_dataset import AutoDataset -from flash.data.callback import BaseDataFetcher from flash.data.data_module import DataModule from flash.data.data_source import DataSource, DefaultDataSources, LabelsState from flash.data.process import Postprocess, Preprocess @@ -29,10 +28,10 @@ class TextDataSource(DataSource): - def __init__(self, tokenizer, max_length: int = 128): + def __init__(self, backbone: str, max_length: int = 128): super().__init__() - self.tokenizer = tokenizer + self.tokenizer = AutoTokenizer.from_pretrained(backbone, use_fast=True) self.max_length = max_length def _tokenize_fn( @@ -52,8 +51,8 @@ def _transform_label(self, label_to_class_mapping: Dict[str, int], target: str, class TextFileDataSource(TextDataSource): - def __init__(self, filetype: str, tokenizer, max_length: int = 128): - super().__init__(tokenizer, max_length=max_length) + def __init__(self, filetype: str, backbone: str, max_length: int = 128): + super().__init__(backbone, max_length=max_length) self.filetype = filetype @@ -110,20 +109,20 @@ def predict_load_data(self, data: Any, dataset: AutoDataset): class TextCSVDataSource(TextFileDataSource): - def __init__(self, tokenizer, max_length: int = 128): - super().__init__("csv", tokenizer, max_length=max_length) + def __init__(self, backbone: str, max_length: int = 128): + super().__init__("csv", backbone, max_length=max_length) class TextJSONDataSource(TextFileDataSource): - def __init__(self, tokenizer, max_length: int = 128): - super().__init__("json", tokenizer, max_length=max_length) + def __init__(self, backbone: str, max_length: int = 128): + super().__init__("json", backbone, max_length=max_length) class TextSentencesDataSource(TextDataSource): - def __init__(self, tokenizer, max_length: int = 128): - super().__init__(tokenizer, max_length=max_length) + def __init__(self, backbone: str, max_length: int = 128): + super().__init__(backbone, max_length=max_length) def load_data( self, @@ -150,17 +149,15 @@ def __init__( self.backbone = backbone self.max_length = max_length - self.tokenizer = AutoTokenizer.from_pretrained(backbone, use_fast=True) - super().__init__( train_transform=train_transform, val_transform=val_transform, test_transform=test_transform, predict_transform=predict_transform, data_sources={ - DefaultDataSources.CSV: TextCSVDataSource(self.tokenizer, max_length=max_length), - DefaultDataSources.JSON: TextJSONDataSource(self.tokenizer, max_length=max_length), - "sentences": TextSentencesDataSource(self.tokenizer, max_length=max_length), + DefaultDataSources.CSV: TextCSVDataSource(self.backbone, max_length=max_length), + DefaultDataSources.JSON: TextJSONDataSource(self.backbone, max_length=max_length), + "sentences": TextSentencesDataSource(self.backbone, max_length=max_length), }, default_data_source="sentences", ) diff --git a/flash/text/seq2seq/core/data.py b/flash/text/seq2seq/core/data.py index 3fbfecd6df..d8de69acfc 100644 --- a/flash/text/seq2seq/core/data.py +++ b/flash/text/seq2seq/core/data.py @@ -30,14 
+30,14 @@ class Seq2SeqDataSource(DataSource): def __init__( self, - tokenizer, + backbone: str, max_source_length: int = 128, max_target_length: int = 128, padding: Union[str, bool] = 'max_length' ): super().__init__() - self.tokenizer = tokenizer + self.tokenizer = AutoTokenizer.from_pretrained(backbone, use_fast=True) self.max_source_length = max_source_length self.max_target_length = max_target_length self.padding = padding @@ -69,12 +69,12 @@ class Seq2SeqFileDataSource(Seq2SeqDataSource): def __init__( self, filetype: str, - tokenizer, + backbone: str, max_source_length: int = 128, max_target_length: int = 128, padding: Union[str, bool] = 'max_length', ): - super().__init__(tokenizer, max_source_length, max_target_length, padding) + super().__init__(backbone, max_source_length, max_target_length, padding) self.filetype = filetype @@ -113,14 +113,14 @@ class Seq2SeqCSVDataSource(Seq2SeqFileDataSource): def __init__( self, - tokenizer, + backbone: str, max_source_length: int = 128, max_target_length: int = 128, padding: Union[str, bool] = 'max_length', ): super().__init__( "csv", - tokenizer, + backbone, max_source_length=max_source_length, max_target_length=max_target_length, padding=padding, @@ -131,14 +131,14 @@ class Seq2SeqJSONDataSource(Seq2SeqFileDataSource): def __init__( self, - tokenizer, + backbone: str, max_source_length: int = 128, max_target_length: int = 128, padding: Union[str, bool] = 'max_length', ): super().__init__( "json", - tokenizer, + backbone, max_source_length=max_source_length, max_target_length=max_target_length, padding=padding, @@ -176,8 +176,6 @@ def __init__( self.max_source_length = max_source_length self.padding = padding - self.tokenizer = AutoTokenizer.from_pretrained(backbone, use_fast=True) - super().__init__( train_transform=train_transform, val_transform=val_transform, @@ -185,19 +183,19 @@ def __init__( predict_transform=predict_transform, data_sources={ DefaultDataSources.CSV: Seq2SeqCSVDataSource( - self.tokenizer, + self.backbone, max_source_length=max_source_length, max_target_length=max_target_length, padding=padding, ), DefaultDataSources.JSON: Seq2SeqJSONDataSource( - self.tokenizer, + self.backbone, max_source_length=max_source_length, max_target_length=max_target_length, padding=padding, ), "sentences": Seq2SeqSentencesDataSource( - self.tokenizer, + self.backbone, max_source_length=max_source_length, max_target_length=max_target_length, padding=padding, diff --git a/flash/text/seq2seq/core/model.py b/flash/text/seq2seq/core/model.py index 8971584bde..3caec065ca 100644 --- a/flash/text/seq2seq/core/model.py +++ b/flash/text/seq2seq/core/model.py @@ -120,7 +120,7 @@ def _initialize_model_specific_parameters(self): @property def tokenizer(self) -> PreTrainedTokenizerBase: - return self.data_pipeline._preprocess_pipeline.tokenizer + return self.data_pipeline.data_source.tokenizer def tokenize_labels(self, labels: Tensor) -> List[str]: label_str = self.tokenizer.batch_decode(labels, skip_special_tokens=True) From 4ee1dd4588108df43c0289f71d3a1192eeaf190d Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Fri, 7 May 2021 16:17:18 +0100 Subject: [PATCH 72/78] Fixes --- flash/data/data_module.py | 56 ++++++++++++++--------------- flash/vision/classification/data.py | 8 ++--- flash/vision/detection/data.py | 14 ++++---- 3 files changed, 39 insertions(+), 39 deletions(-) diff --git a/flash/data/data_module.py b/flash/data/data_module.py index 47129ce811..f64c25284a 100644 --- a/flash/data/data_module.py +++ b/flash/data/data_module.py @@ -340,10 
+340,10 @@ def from_data_source( val_data: Any = None, test_data: Any = None, predict_data: Any = None, - train_transform: Optional[Dict] = None, - val_transform: Optional[Dict] = None, - test_transform: Optional[Dict] = None, - predict_transform: Optional[Dict] = None, + train_transform: Optional[Dict[str, Callable]] = None, + val_transform: Optional[Dict[str, Callable]] = None, + test_transform: Optional[Dict[str, Callable]] = None, + predict_transform: Optional[Dict[str, Callable]] = None, data_fetcher: BaseDataFetcher = None, preprocess: Optional[Preprocess] = None, val_split: Optional[float] = None, @@ -388,10 +388,10 @@ def from_folders( val_folder: Optional[str] = None, test_folder: Optional[str] = None, predict_folder: Optional[str] = None, - train_transform: Optional[Dict] = None, - val_transform: Optional[Dict] = None, - test_transform: Optional[Dict] = None, - predict_transform: Optional[Dict] = None, + train_transform: Optional[Dict[str, Callable]] = None, + val_transform: Optional[Dict[str, Callable]] = None, + test_transform: Optional[Dict[str, Callable]] = None, + predict_transform: Optional[Dict[str, Callable]] = None, data_fetcher: BaseDataFetcher = None, preprocess: Optional[Preprocess] = None, val_split: Optional[float] = None, @@ -427,10 +427,10 @@ def from_files( test_files: Optional[Sequence[str]] = None, test_targets: Optional[Sequence[Any]] = None, predict_files: Optional[Sequence[str]] = None, - train_transform: Optional[Dict] = None, - val_transform: Optional[Dict] = None, - test_transform: Optional[Dict] = None, - predict_transform: Optional[Dict] = None, + train_transform: Optional[Dict[str, Callable]] = None, + val_transform: Optional[Dict[str, Callable]] = None, + test_transform: Optional[Dict[str, Callable]] = None, + predict_transform: Optional[Dict[str, Callable]] = None, data_fetcher: BaseDataFetcher = None, preprocess: Optional[Preprocess] = None, val_split: Optional[float] = None, @@ -466,10 +466,10 @@ def from_tensors( test_data: Optional[Collection[torch.Tensor]] = None, test_targets: Optional[Sequence[Any]] = None, predict_data: Optional[Collection[torch.Tensor]] = None, - train_transform: Optional[Dict] = None, - val_transform: Optional[Dict] = None, - test_transform: Optional[Dict] = None, - predict_transform: Optional[Dict] = None, + train_transform: Optional[Dict[str, Callable]] = None, + val_transform: Optional[Dict[str, Callable]] = None, + test_transform: Optional[Dict[str, Callable]] = None, + predict_transform: Optional[Dict[str, Callable]] = None, data_fetcher: BaseDataFetcher = None, preprocess: Optional[Preprocess] = None, val_split: Optional[float] = None, @@ -505,10 +505,10 @@ def from_numpy( test_data: Optional[Collection[np.ndarray]] = None, test_targets: Optional[Sequence[Any]] = None, predict_data: Optional[Collection[np.ndarray]] = None, - train_transform: Optional[Dict] = None, - val_transform: Optional[Dict] = None, - test_transform: Optional[Dict] = None, - predict_transform: Optional[Dict] = None, + train_transform: Optional[Dict[str, Callable]] = None, + val_transform: Optional[Dict[str, Callable]] = None, + test_transform: Optional[Dict[str, Callable]] = None, + predict_transform: Optional[Dict[str, Callable]] = None, data_fetcher: BaseDataFetcher = None, preprocess: Optional[Preprocess] = None, val_split: Optional[float] = None, @@ -543,10 +543,10 @@ def from_json( val_file: Optional[str] = None, test_file: Optional[str] = None, predict_file: Optional[str] = None, - train_transform: Optional[Dict] = None, - val_transform: 
Optional[Dict] = None, - test_transform: Optional[Dict] = None, - predict_transform: Optional[Dict] = None, + train_transform: Optional[Dict[str, Callable]] = None, + val_transform: Optional[Dict[str, Callable]] = None, + test_transform: Optional[Dict[str, Callable]] = None, + predict_transform: Optional[Dict[str, Callable]] = None, data_fetcher: BaseDataFetcher = None, preprocess: Optional[Preprocess] = None, val_split: Optional[float] = None, @@ -581,10 +581,10 @@ def from_csv( val_file: Optional[str] = None, test_file: Optional[str] = None, predict_file: Optional[str] = None, - train_transform: Optional[Dict] = None, - val_transform: Optional[Dict] = None, - test_transform: Optional[Dict] = None, - predict_transform: Optional[Dict] = None, + train_transform: Optional[Dict[str, Callable]] = None, + val_transform: Optional[Dict[str, Callable]] = None, + test_transform: Optional[Dict[str, Callable]] = None, + predict_transform: Optional[Dict[str, Callable]] = None, data_fetcher: BaseDataFetcher = None, preprocess: Optional[Preprocess] = None, val_split: Optional[float] = None, diff --git a/flash/vision/classification/data.py b/flash/vision/classification/data.py index 6b061f1947..05ab1f5f29 100644 --- a/flash/vision/classification/data.py +++ b/flash/vision/classification/data.py @@ -39,10 +39,10 @@ class ImageClassificationPreprocess(Preprocess): def __init__( self, - train_transform: Optional[Union[Dict[str, Callable]]] = None, - val_transform: Optional[Union[Dict[str, Callable]]] = None, - test_transform: Optional[Union[Dict[str, Callable]]] = None, - predict_transform: Optional[Union[Dict[str, Callable]]] = None, + train_transform: Optional[Dict[str, Callable]] = None, + val_transform: Optional[Dict[str, Callable]] = None, + test_transform: Optional[Dict[str, Callable]] = None, + predict_transform: Optional[Dict[str, Callable]] = None, image_size: Tuple[int, int] = (196, 196), ): self.image_size = image_size diff --git a/flash/vision/detection/data.py b/flash/vision/detection/data.py index 676a602f67..8c26a17d08 100644 --- a/flash/vision/detection/data.py +++ b/flash/vision/detection/data.py @@ -93,10 +93,10 @@ class ObjectDetectionPreprocess(Preprocess): def __init__( self, - train_transform: Optional[Union[Dict[str, Callable]]] = None, - val_transform: Optional[Union[Dict[str, Callable]]] = None, - test_transform: Optional[Union[Dict[str, Callable]]] = None, - predict_transform: Optional[Union[Dict[str, Callable]]] = None, + train_transform: Optional[Dict[str, Callable]] = None, + val_transform: Optional[Dict[str, Callable]] = None, + test_transform: Optional[Dict[str, Callable]] = None, + predict_transform: Optional[Dict[str, Callable]] = None, ): super().__init__( train_transform=train_transform, @@ -147,13 +147,13 @@ def from_coco( cls, train_folder: Optional[str] = None, train_ann_file: Optional[str] = None, - train_transform: Optional[Dict[str, Module]] = None, + train_transform: Optional[Dict[str, Callable]] = None, val_folder: Optional[str] = None, val_ann_file: Optional[str] = None, - val_transform: Optional[Dict[str, Module]] = None, + val_transform: Optional[Dict[str, Callable]] = None, test_folder: Optional[str] = None, test_ann_file: Optional[str] = None, - test_transform: Optional[Dict[str, Module]] = None, + test_transform: Optional[Dict[str, Callable]] = None, batch_size: int = 4, num_workers: Optional[int] = None, preprocess: Preprocess = None, From ce3fcf27c3c0a4c48b9396bb44d682621c82181c Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Fri, 7 May 2021 16:20:17 
+0100 Subject: [PATCH 73/78] Fixes --- flash/data/data_source.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/flash/data/data_source.py b/flash/data/data_source.py index 40402f7f49..850f885078 100644 --- a/flash/data/data_source.py +++ b/flash/data/data_source.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import os +import typing from dataclasses import dataclass from inspect import signature from typing import Any, Callable, Dict, Generic, Iterable, List, Mapping, Optional, Sequence, Tuple, TypeVar, Union @@ -104,7 +105,7 @@ def generate_dataset( if not is_none: from flash.data.data_pipeline import DataPipeline - mock_dataset = MockDataset() + mock_dataset = typing.cast(AutoDataset, MockDataset()) with CurrentRunningStageFuncContext(running_stage, "load_data", self): load_data: Callable = getattr( self, DataPipeline._resolve_function_hierarchy( From ed22b109c7d205012524671011542299e1204c0f Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Fri, 7 May 2021 16:24:59 +0100 Subject: [PATCH 74/78] revert --- flash/data/callback.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/flash/data/callback.py b/flash/data/callback.py index 8609868f7c..1221046a31 100644 --- a/flash/data/callback.py +++ b/flash/data/callback.py @@ -156,8 +156,7 @@ def __init__(self, enabled: bool = False): def _store(self, data: Any, fn_name: str, running_stage: RunningStage) -> None: if self.enabled: store = self.batches[_STAGES_PREFIX[running_stage]] - if fn_name not in store: - store[fn_name] = [] + store.setdefault(fn_name, []) store[fn_name].append(data) def on_load_sample(self, sample: Any, running_stage: RunningStage) -> None: From f453d0375f7b8d0e898ce437cbb9a23a71e38552 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Fri, 7 May 2021 16:34:43 +0100 Subject: [PATCH 75/78] Updates --- flash/data/auto_dataset.py | 4 ++-- flash/data/data_source.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/flash/data/auto_dataset.py b/flash/data/auto_dataset.py index 55a6352e72..165852c927 100644 --- a/flash/data/auto_dataset.py +++ b/flash/data/auto_dataset.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
from inspect import signature -from typing import Any, Callable, Generic, Iterable, Sequence, TYPE_CHECKING, TypeVar +from typing import Any, Callable, Generic, Iterable, Optional, Sequence, TYPE_CHECKING, TypeVar from pytorch_lightning.trainer.states import RunningStage from torch.utils.data import Dataset, IterableDataset @@ -69,7 +69,7 @@ def running_stage(self, running_stage: RunningStage) -> None: self._load_sample_context = CurrentRunningStageFuncContext(self.running_stage, "load_sample", self.data_source) - self.load_sample: Callable = getattr( + self.load_sample: Callable[[DATA_TYPE, Optional[Any]], Any] = getattr( self.data_source, DataPipeline._resolve_function_hierarchy( 'load_sample', diff --git a/flash/data/data_source.py b/flash/data/data_source.py index 850f885078..bd9c174ab5 100644 --- a/flash/data/data_source.py +++ b/flash/data/data_source.py @@ -107,7 +107,7 @@ def generate_dataset( mock_dataset = typing.cast(AutoDataset, MockDataset()) with CurrentRunningStageFuncContext(running_stage, "load_data", self): - load_data: Callable = getattr( + load_data: Callable[[DATA_TYPE, Optional[Any]], Any] = getattr( self, DataPipeline._resolve_function_hierarchy( "load_data", self, From 1ae8c56f631ec75645ff92b14af926c5035d5a9e Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Fri, 7 May 2021 16:51:37 +0100 Subject: [PATCH 76/78] Fixes --- flash/core/model.py | 2 +- flash/video/classification/data.py | 16 ++++++---------- flash/vision/classification/data.py | 8 +------- flash/vision/detection/data.py | 7 +------ 4 files changed, 9 insertions(+), 24 deletions(-) diff --git a/flash/core/model.py b/flash/core/model.py index 4115711672..bc1ed7bad9 100644 --- a/flash/core/model.py +++ b/flash/core/model.py @@ -322,7 +322,7 @@ def build_data_pipeline( data_source = data_source or old_data_source - if isinstance(data_source, str): + if data_source is None or isinstance(data_source, str): if preprocess is None: data_source = DataSource() # TODO: warn the user that we are not using the specified data source else: diff --git a/flash/video/classification/data.py b/flash/video/classification/data.py index 0b035f786a..f3cda86cb6 100644 --- a/flash/video/classification/data.py +++ b/flash/video/classification/data.py @@ -158,16 +158,12 @@ def __init__( def get_state_dict(self) -> Dict[str, Any]: return { - 'train_transform': self._train_transform, - 'val_transform': self._val_transform, - 'test_transform': self._test_transform, - 'predict_transform': self._predict_transform, - 'clip_sampler': self.clip_sampler, - 'clip_duration': self.clip_duration, - 'clip_sampler_kwargs': self.clip_sampler_kwargs, - 'video_sampler': self.video_sampler, - 'decode_audio': self.decode_audio, - 'decoder': self.decoder, + "clip_sampler": self.clip_sampler, + "clip_duration": self.clip_duration, + "clip_sampler_kwargs": self.clip_sampler_kwargs, + "video_sampler": self.video_sampler, + "decode_audio": self.decode_audio, + "decoder": self.decoder, } @classmethod diff --git a/flash/vision/classification/data.py b/flash/vision/classification/data.py index 05ab1f5f29..5463a8a11e 100644 --- a/flash/vision/classification/data.py +++ b/flash/vision/classification/data.py @@ -61,13 +61,7 @@ def __init__( ) def get_state_dict(self) -> Dict[str, Any]: - return { - "train_transform": self._train_transform, - "val_transform": self._val_transform, - "test_transform": self._test_transform, - "predict_transform": self._predict_transform, - "image_size": self.image_size - } + return {"image_size": self.image_size} @classmethod def 
load_state_dict(cls, state_dict: Dict[str, Any], strict: bool = False): diff --git a/flash/vision/detection/data.py b/flash/vision/detection/data.py index 8c26a17d08..28d406c4e3 100644 --- a/flash/vision/detection/data.py +++ b/flash/vision/detection/data.py @@ -114,12 +114,7 @@ def collate(self, samples: Any) -> Any: return {key: [sample[key] for sample in samples] for key in samples[0]} def get_state_dict(self) -> Dict[str, Any]: - return { - "train_transform": self._train_transform, - "val_transform": self._val_transform, - "test_transform": self._test_transform, - "predict_transform": self._predict_transform, - } + return {} @classmethod def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool = False): From 1088022b1532f48a7e808f58206a81e40bc8a289 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Fri, 7 May 2021 17:06:05 +0100 Subject: [PATCH 77/78] Fixes --- flash/core/model.py | 4 ++-- flash/data/auto_dataset.py | 4 ++-- flash/data/data_source.py | 8 ++++---- flash/data/process.py | 4 ++-- 4 files changed, 10 insertions(+), 10 deletions(-) diff --git a/flash/core/model.py b/flash/core/model.py index bc1ed7bad9..6c453ae0bf 100644 --- a/flash/core/model.py +++ b/flash/core/model.py @@ -174,7 +174,7 @@ def predict( """ running_stage = RunningStage.PREDICTING - data_pipeline = self.build_data_pipeline(data_source, data_pipeline) + data_pipeline = self.build_data_pipeline(data_source or "default", data_pipeline) x = [x for x in data_pipeline.data_source.generate_dataset(x, running_stage)] x = data_pipeline.worker_preprocessor(running_stage)(x) @@ -322,7 +322,7 @@ def build_data_pipeline( data_source = data_source or old_data_source - if data_source is None or isinstance(data_source, str): + if isinstance(data_source, str): if preprocess is None: data_source = DataSource() # TODO: warn the user that we are not using the specified data source else: diff --git a/flash/data/auto_dataset.py b/flash/data/auto_dataset.py index 165852c927..191385900d 100644 --- a/flash/data/auto_dataset.py +++ b/flash/data/auto_dataset.py @@ -92,7 +92,7 @@ def _call_load_sample(self, sample: Any) -> Any: return sample -class AutoDataset(BaseAutoDataset[Sequence[Any]], Dataset): +class AutoDataset(BaseAutoDataset[Sequence], Dataset): def __getitem__(self, index: int) -> Any: return self._call_load_sample(self.data[index]) @@ -101,7 +101,7 @@ def __len__(self) -> int: return len(self.data) -class IterableAutoDataset(BaseAutoDataset[Iterable[Any]], IterableDataset): +class IterableAutoDataset(BaseAutoDataset[Iterable], IterableDataset): def __iter__(self): self.data_iter = iter(self.data) diff --git a/flash/data/data_source.py b/flash/data/data_source.py index bd9c174ab5..e637eab923 100644 --- a/flash/data/data_source.py +++ b/flash/data/data_source.py @@ -157,7 +157,7 @@ def __hash__(self) -> int: class SequenceDataSource( Generic[SEQUENCE_DATA_TYPE], - DataSource[Tuple[Sequence[SEQUENCE_DATA_TYPE], Optional[Sequence[Any]]]], + DataSource[Tuple[Sequence[SEQUENCE_DATA_TYPE], Optional[Sequence]]], ): def __init__(self, labels: Optional[Sequence[str]] = None): @@ -170,7 +170,7 @@ def __init__(self, labels: Optional[Sequence[str]] = None): def load_data( self, - data: Tuple[Sequence[SEQUENCE_DATA_TYPE], Optional[Sequence[Any]]], + data: Tuple[Sequence[SEQUENCE_DATA_TYPE], Optional[Sequence]], dataset: Optional[Any] = None, ) -> Sequence[Mapping[str, Any]]: # TODO: Bring back the code to work out how many classes there are @@ -219,7 +219,7 @@ def isdir(data: Union[str, Tuple[List[str], List[Any]]]) -> bool: 
def load_data(self, data: Union[str, Tuple[List[str], List[Any]]], - dataset: Optional[Any] = None) -> Iterable[Mapping[str, Any]]: + dataset: Optional[Any] = None) -> Sequence[Mapping[str, Any]]: if self.isdir(data): classes, class_to_idx = self.find_classes(data) if not classes: @@ -241,7 +241,7 @@ def load_data(self, def predict_load_data(self, data: Union[str, List[str]], - dataset: Optional[Any] = None) -> Iterable[Mapping[str, Any]]: + dataset: Optional[Any] = None) -> Sequence[Mapping[str, Any]]: if self.isdir(data): data = [os.path.join(data, file) for file in os.listdir(data)] diff --git a/flash/data/process.py b/flash/data/process.py index e279010f0e..2c6e04a2f4 100644 --- a/flash/data/process.py +++ b/flash/data/process.py @@ -403,8 +403,8 @@ def per_batch_transform_on_device(self, batch: Any) -> Any: """ return self.current_transform(batch) - def data_source_of_name(self, data_source_name: Optional[str]) -> Optional[DATA_SOURCE_TYPE]: - if data_source_name is None: + def data_source_of_name(self, data_source_name: str) -> Optional[DATA_SOURCE_TYPE]: + if data_source_name == "default": data_source_name = self._default_data_source data_sources = self._data_sources if data_source_name in data_sources: From 9032be4824a37bb987e6f540557c891ae00338e2 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Fri, 7 May 2021 17:39:50 +0100 Subject: [PATCH 78/78] Fixes --- flash/data/process.py | 63 ++++++++++++++--------- flash/tabular/classification/data/data.py | 1 + flash/text/classification/data.py | 1 + flash/text/seq2seq/core/data.py | 1 + flash/video/classification/data.py | 1 + flash/vision/classification/data.py | 6 ++- flash/vision/detection/data.py | 6 ++- 7 files changed, 52 insertions(+), 27 deletions(-) diff --git a/flash/data/process.py b/flash/data/process.py index 2c6e04a2f4..050847dfa0 100644 --- a/flash/data/process.py +++ b/flash/data/process.py @@ -216,10 +216,10 @@ def __init__( super().__init__() # resolve the default transforms - train_transform = train_transform or self.default_train_transforms() - val_transform = val_transform or self.default_val_transforms() - test_transform = test_transform or self.default_test_transforms() - predict_transform = predict_transform or self.default_predict_transforms() + train_transform = train_transform or self.default_train_transforms + val_transform = val_transform or self.default_val_transforms + test_transform = test_transform or self.default_test_transforms + predict_transform = predict_transform or self.default_predict_transforms # used to keep track of provided transforms self._train_collate_in_worker_from_transform: Optional[bool] = None @@ -228,31 +228,44 @@ def __init__( self._test_collate_in_worker_from_transform: Optional[bool] = None # store the transform before conversion to modules. 
- self._train_transform = self._check_transforms(train_transform, RunningStage.TRAINING) - self._val_transform = self._check_transforms(val_transform, RunningStage.VALIDATING) - self._test_transform = self._check_transforms(test_transform, RunningStage.TESTING) - self._predict_transform = self._check_transforms(predict_transform, RunningStage.PREDICTING) + self.train_transform = self._check_transforms(train_transform, RunningStage.TRAINING) + self.val_transform = self._check_transforms(val_transform, RunningStage.VALIDATING) + self.test_transform = self._check_transforms(test_transform, RunningStage.TESTING) + self.predict_transform = self._check_transforms(predict_transform, RunningStage.PREDICTING) - self.train_transform = convert_to_modules(self._train_transform) - self.val_transform = convert_to_modules(self._val_transform) - self.test_transform = convert_to_modules(self._test_transform) - self.predict_transform = convert_to_modules(self._predict_transform) + self._train_transform = convert_to_modules(self.train_transform) + self._val_transform = convert_to_modules(self.val_transform) + self._test_transform = convert_to_modules(self.test_transform) + self._predict_transform = convert_to_modules(self.predict_transform) self._data_sources = data_sources self._default_data_source = default_data_source self._callbacks: List[FlashCallback] = [] + @property def default_train_transforms(self) -> Optional[Dict[str, Callable]]: - pass + return None + @property def default_val_transforms(self) -> Optional[Dict[str, Callable]]: - pass + return None + @property def default_test_transforms(self) -> Optional[Dict[str, Callable]]: - pass + return None + @property def default_predict_transforms(self) -> Optional[Dict[str, Callable]]: - pass + return None + + @property + def transforms(self) -> Dict[str, Optional[Dict[str, Callable]]]: + return { + "train_transform": self.train_transform, + "val_transform": self.val_transform, + "test_transform": self.test_transform, + "predict_transform": self.predict_transform, + } def _save_to_state_dict(self, destination, prefix, keep_vars): preprocess_state_dict = self.get_state_dict() @@ -327,14 +340,14 @@ def _get_transform(self, transform: Dict[str, Callable]) -> Callable: @property def current_transform(self) -> Callable: - if self.training and self.train_transform: - return self._get_transform(self.train_transform) - elif self.validating and self.val_transform: - return self._get_transform(self.val_transform) - elif self.testing and self.test_transform: - return self._get_transform(self.test_transform) - elif self.predicting and self.predict_transform: - return self._get_transform(self.predict_transform) + if self.training and self._train_transform: + return self._get_transform(self._train_transform) + elif self.validating and self._val_transform: + return self._get_transform(self._val_transform) + elif self.testing and self._test_transform: + return self._get_transform(self._test_transform) + elif self.predicting and self._predict_transform: + return self._get_transform(self._predict_transform) else: return self._identity @@ -434,7 +447,7 @@ def __init__( ) def get_state_dict(self) -> Dict[str, Any]: - return {} + return {**self.transforms} @classmethod def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool): diff --git a/flash/tabular/classification/data/data.py b/flash/tabular/classification/data/data.py index 2969405275..1b4ad6b9bd 100644 --- a/flash/tabular/classification/data/data.py +++ b/flash/tabular/classification/data/data.py @@ -149,6 
+149,7 @@ def __init__( def get_state_dict(self, strict: bool = False) -> Dict[str, Any]: return { + **self.transforms, "cat_cols": self.cat_cols, "num_cols": self.num_cols, "target_col": self.target_col, diff --git a/flash/text/classification/data.py b/flash/text/classification/data.py index b7662c9838..7f867fb76c 100644 --- a/flash/text/classification/data.py +++ b/flash/text/classification/data.py @@ -164,6 +164,7 @@ def __init__( def get_state_dict(self) -> Dict[str, Any]: return { + **self.transforms, "backbone": self.backbone, "max_length": self.max_length, } diff --git a/flash/text/seq2seq/core/data.py b/flash/text/seq2seq/core/data.py index d8de69acfc..f7968ee4a7 100644 --- a/flash/text/seq2seq/core/data.py +++ b/flash/text/seq2seq/core/data.py @@ -206,6 +206,7 @@ def __init__( def get_state_dict(self) -> Dict[str, Any]: return { + **self.transforms, "backbone": self.backbone, "max_source_length": self.max_source_length, "max_target_length": self.max_target_length, diff --git a/flash/video/classification/data.py b/flash/video/classification/data.py index f3cda86cb6..5aefd5d14a 100644 --- a/flash/video/classification/data.py +++ b/flash/video/classification/data.py @@ -158,6 +158,7 @@ def __init__( def get_state_dict(self) -> Dict[str, Any]: return { + **self.transforms, "clip_sampler": self.clip_sampler, "clip_duration": self.clip_duration, "clip_sampler_kwargs": self.clip_sampler_kwargs, diff --git a/flash/vision/classification/data.py b/flash/vision/classification/data.py index 5463a8a11e..928605b244 100644 --- a/flash/vision/classification/data.py +++ b/flash/vision/classification/data.py @@ -61,7 +61,7 @@ def __init__( ) def get_state_dict(self) -> Dict[str, Any]: - return {"image_size": self.image_size} + return {**self.transforms, "image_size": self.image_size} @classmethod def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool = False): @@ -75,15 +75,19 @@ def collate(self, samples: Sequence[Dict[str, Any]]) -> Any: sample[key] = sample[key].squeeze(0) return default_collate(samples) + @property def default_train_transforms(self) -> Optional[Dict[str, Callable]]: return default_train_transforms(self.image_size) + @property def default_val_transforms(self) -> Optional[Dict[str, Callable]]: return default_val_transforms(self.image_size) + @property def default_test_transforms(self) -> Optional[Dict[str, Callable]]: return default_val_transforms(self.image_size) + @property def default_predict_transforms(self) -> Optional[Dict[str, Callable]]: return default_val_transforms(self.image_size) diff --git a/flash/vision/detection/data.py b/flash/vision/detection/data.py index 28d406c4e3..528a74a99d 100644 --- a/flash/vision/detection/data.py +++ b/flash/vision/detection/data.py @@ -114,21 +114,25 @@ def collate(self, samples: Any) -> Any: return {key: [sample[key] for sample in samples] for key in samples[0]} def get_state_dict(self) -> Dict[str, Any]: - return {} + return {**self.transforms} @classmethod def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool = False): return cls(**state_dict) + @property def default_train_transforms(self) -> Optional[Dict[str, Callable]]: return default_transforms() + @property def default_val_transforms(self) -> Optional[Dict[str, Callable]]: return default_transforms() + @property def default_test_transforms(self) -> Optional[Dict[str, Callable]]: return default_transforms() + @property def default_predict_transforms(self) -> Optional[Dict[str, Callable]]: return default_transforms()
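
A short usage note on the tabular API change in [PATCH 69/78], which renames ``TabularData.from_df`` to ``TabularData.from_data_frame`` and the ``train_df`` / ``val_df`` / ``test_df`` / ``predict_df`` arguments to ``train_data_frame`` / ``val_data_frame`` / ``test_data_frame`` / ``predict_data_frame``. A minimal sketch of the renamed entry point, mirroring the updated tests but with a hypothetical toy DataFrame standing in for the test fixtures::

    import pandas as pd

    from flash.tabular import TabularData

    # hypothetical toy data; column names follow the updated tests
    df = pd.DataFrame({
        "category": ["a", "b", "a", "b"],
        "scalar_a": [0.1, 0.2, 0.3, 0.4],
        "scalar_b": [1.0, 2.0, 3.0, 4.0],
        "label": [0, 1, 0, 1],
    })

    dm = TabularData.from_data_frame(
        categorical_cols=["category"],
        numerical_cols=["scalar_a", "scalar_b"],
        target_col="label",
        train_data_frame=df,
        val_data_frame=df.copy(),
        batch_size=2,
        num_workers=0,
    )

    # batches are structured as in the updated tests
    (cat, num), target = next(iter(dm.train_dataloader()))

The corresponding ``from_csv`` path is unchanged apart from reading the files into frames first, as shown in the [PATCH 69/78] diff of ``flash/tabular/classification/data/data.py``.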
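
[PATCH 78/78] turns the ``default_*_transforms`` hooks on ``Preprocess`` into properties and exposes the user-provided transforms through a ``transforms`` property that the ``get_state_dict`` implementations merge into their returned state. A rough sketch of overriding the property-based hooks, assuming a hypothetical ``MyPreprocess`` subclass; a real subclass (for example ``ImageClassificationPreprocess`` in the diff) would also register its data sources in ``__init__`` and implement the state-dict methods::

    from typing import Callable, Dict, Optional

    from torchvision import transforms as T

    from flash.data.process import Preprocess


    class MyPreprocess(Preprocess):
        # hypothetical subclass; only the transform hooks are shown

        @property
        def default_train_transforms(self) -> Optional[Dict[str, Callable]]:
            # keys correspond to the preprocess hooks, e.g. ``to_tensor_transform``
            return {"to_tensor_transform": T.ToTensor()}

        @property
        def default_val_transforms(self) -> Optional[Dict[str, Callable]]:
            return {"to_tensor_transform": T.ToTensor()}

With this change, ``Preprocess.__init__`` resolves missing transforms via plain attribute access (``train_transform or self.default_train_transforms``) rather than through the removed ``_resolve_transforms`` helper.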