diff --git a/docs/source/reference/image_classification.rst b/docs/source/reference/image_classification.rst index 45126f20de..c457be23c6 100644 --- a/docs/source/reference/image_classification.rst +++ b/docs/source/reference/image_classification.rst @@ -8,7 +8,7 @@ Image Classification ******** The task ******** -The task of identifying what is in an image is called image classification. Typically, Image Classification is used to identify images containing a single object. The task predicts which ‘class’ the image most likely belongs to with a degree of certainty. A class is a label that desecribes what is in an image, such as ‘car’, ‘house’, ‘cat’ etc. For example, we can train the image classifier task on images of ants and it will learn to predict the probability that an image contains an ant. +The task of identifying what is in an image is called image classification. Typically, Image Classification is used to identify images containing a single object. The task predicts which ‘class’ the image most likely belongs to with a degree of certainty. A class is a label that describes what is in an image, such as ‘car’, ‘house’, ‘cat’ etc. For example, we can train the image classifier task on images of ants and it will learn to predict the probability that an image contains an ant. 
------ diff --git a/flash/core/classification.py b/flash/core/classification.py index f82de91c5f..970466dbf4 100644 --- a/flash/core/classification.py +++ b/flash/core/classification.py @@ -14,6 +14,7 @@ from typing import Any import torch +import torch.nn.functional as F from flash.core.model import Task from flash.data.process import Postprocess @@ -29,3 +30,6 @@ class ClassificationTask(Task): def __init__(self, *args, **kwargs) -> None: super().__init__(*args, default_postprocess=ClassificationPostprocess(), **kwargs) + + def to_metrics_format(self, x: torch.Tensor) -> torch.Tensor: + return F.softmax(x, -1) diff --git a/flash/core/model.py b/flash/core/model.py index 68d17173fe..eeaf268b75 100644 --- a/flash/core/model.py +++ b/flash/core/model.py @@ -96,6 +96,7 @@ def step(self, batch: Any, batch_idx: int) -> Any: output = {"y_hat": y_hat} losses = {name: l_fn(y_hat, y) for name, l_fn in self.loss_fn.items()} logs = {} + y_hat = self.to_metrics_format(y_hat) for name, metric in self.metrics.items(): if isinstance(metric, torchmetrics.metric.Metric): metric(y_hat, y) @@ -111,6 +112,9 @@ def step(self, batch: Any, batch_idx: int) -> Any: output["y"] = y return output + def to_metrics_format(self, x: torch.Tensor) -> torch.Tensor: + return x + def forward(self, x: Any) -> Any: return self.model(x) @@ -172,10 +176,10 @@ def configure_finetune_callback(self) -> List[Callback]: @staticmethod def _resolve( - old_preprocess: Optional[Preprocess], - old_postprocess: Optional[Postprocess], - new_preprocess: Optional[Preprocess], - new_postprocess: Optional[Postprocess], + old_preprocess: Optional[Preprocess], + old_postprocess: Optional[Postprocess], + new_preprocess: Optional[Preprocess], + new_postprocess: Optional[Postprocess], ) -> Tuple[Optional[Preprocess], Optional[Postprocess]]: """Resolves the correct :class:`.Preprocess` and :class:`.Postprocess` to use, choosing ``new_*`` if it is not None or a base class (:class:`.Preprocess` or :class:`.Postprocess`) and 
``old_*`` otherwise. @@ -308,3 +312,10 @@ def available_backbones(cls) -> List[str]: if registry is None: return [] return registry.available_keys() + + @classmethod + def available_models(cls) -> List[str]: + registry: Optional[FlashRegistry] = getattr(cls, "models", None) + if registry is None: + return [] + return registry.available_keys() diff --git a/flash/data/auto_dataset.py b/flash/data/auto_dataset.py index 498b67a33d..2ba6dd92f4 100644 --- a/flash/data/auto_dataset.py +++ b/flash/data/auto_dataset.py @@ -11,13 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from contextlib import contextmanager from inspect import signature -from typing import Any, Callable, Iterable, Optional, TYPE_CHECKING +from typing import Any, Callable, Iterable, Iterator, Optional, TYPE_CHECKING +import torch from pytorch_lightning.trainer.states import RunningStage from pytorch_lightning.utilities.warning_utils import rank_zero_warn -from torch.utils.data import Dataset +from torch.utils.data import Dataset, IterableDataset from flash.data.callback import ControlFlow from flash.data.process import Preprocess @@ -27,13 +27,13 @@ from flash.data.data_pipeline import DataPipeline -class AutoDataset(Dataset): +class BaseAutoDataset: DATASET_KEY = "dataset" """ This class is used to encapsulate a Preprocess Object ``load_data`` and ``load_sample`` functions. ``load_data`` will be called within the ``__init__`` function of the AutoDataset if ``running_stage`` - is provided and ``load_sample`` within ``__getitem__`` function. + is provided and ``load_sample`` within ``__getitem__``. """ def __init__( @@ -122,10 +122,19 @@ def _setup(self, stage: Optional[RunningStage]) -> None: "The load_data function of the Autogenerated Dataset changed. " "This is not expected! Preloading Data again to ensure compatibility. This may take some time." 
) - with self._load_data_context: - self.preprocessed_data = self._call_load_data(self.data) + self.setup() self._load_data_called = True + def setup(self): + raise NotImplementedError + + +class AutoDataset(BaseAutoDataset, Dataset): + + def setup(self): + with self._load_data_context: + self.preprocessed_data = self._call_load_data(self.data) + def __getitem__(self, index: int) -> Any: if not self.load_sample and not self.load_data: raise RuntimeError("`__getitem__` for `load_sample` and `load_data` could not be inferred.") @@ -141,3 +150,29 @@ def __len__(self) -> int: if not self.load_sample and not self.load_data: raise RuntimeError("`__len__` for `load_sample` and `load_data` could not be inferred.") return len(self.preprocessed_data) + + +class IterableAutoDataset(BaseAutoDataset, IterableDataset): + + def setup(self): + with self._load_data_context: + self.dataset = self._call_load_data(self.data) + self.dataset_iter = None + + def __iter__(self): + self.dataset_iter = iter(self.dataset) + return self + + def __next__(self) -> Any: + if not self.load_sample and not self.load_data: + raise RuntimeError("`__getitem__` for `load_sample` and `load_data` could not be inferred.") + + data = next(self.dataset_iter) + + if self.load_sample: + with self._load_sample_context: + data: Any = self._call_load_sample(data) + if self.control_flow_callback: + self.control_flow_callback.on_load_sample(data, self.running_stage) + return data + return data diff --git a/flash/data/data_module.py b/flash/data/data_module.py index dfcf213662..c8986ad024 100644 --- a/flash/data/data_module.py +++ b/flash/data/data_module.py @@ -23,7 +23,7 @@ from torch.utils.data import DataLoader, Dataset from torch.utils.data.dataset import IterableDataset, Subset -from flash.data.auto_dataset import AutoDataset +from flash.data.auto_dataset import BaseAutoDataset, IterableAutoDataset from flash.data.base_viz import BaseVisualization from flash.data.callback import BaseDataFetcher from 
flash.data.data_pipeline import DataPipeline, Postprocess, Preprocess @@ -212,15 +212,16 @@ def set_running_stages(self): self.set_dataset_attribute(self._predict_ds, 'running_stage', RunningStage.PREDICTING) def _resolve_collate_fn(self, dataset: Dataset, running_stage: RunningStage) -> Optional[Callable]: - if isinstance(dataset, AutoDataset): + if isinstance(dataset, BaseAutoDataset): return self.data_pipeline.worker_preprocessor(running_stage) def _train_dataloader(self) -> DataLoader: train_ds: Dataset = self._train_ds() if isinstance(self._train_ds, Callable) else self._train_ds + shuffle = not isinstance(train_ds, (IterableDataset, IterableAutoDataset)) return DataLoader( train_ds, batch_size=self.batch_size, - shuffle=True, + shuffle=shuffle, num_workers=self.num_workers, pin_memory=True, drop_last=True, @@ -249,10 +250,13 @@ def _test_dataloader(self) -> DataLoader: def _predict_dataloader(self) -> DataLoader: predict_ds: Dataset = self._predict_ds() if isinstance(self._predict_ds, Callable) else self._predict_ds + if isinstance(predict_ds, IterableAutoDataset): + batch_size = self.batch_size + else: + batch_size = min(self.batch_size, len(predict_ds) if len(predict_ds) > 0 else 1) return DataLoader( predict_ds, - batch_size=min(self.batch_size, - len(predict_ds) if len(predict_ds) > 0 else 1), + batch_size=batch_size, num_workers=self.num_workers, pin_memory=True, collate_fn=self._resolve_collate_fn(predict_ds, RunningStage.PREDICTING) @@ -263,6 +267,13 @@ def generate_auto_dataset(self, *args, **kwargs): return None return self.data_pipeline._generate_auto_dataset(*args, **kwargs) + @property + def num_classes(self) -> Optional[int]: + return ( + getattr(self.train_dataset, "num_classes", None) or getattr(self.val_dataset, "num_classes", None) + or getattr(self.test_dataset, "num_classes", None) + ) + @property def preprocess(self) -> Preprocess: return self._preprocess or self.preprocess_cls() @@ -292,9 +303,10 @@ def autogenerate_dataset( 
whole_data_load_fn: Optional[Callable] = None, per_sample_load_fn: Optional[Callable] = None, data_pipeline: Optional[DataPipeline] = None, - ) -> AutoDataset: + use_iterable_auto_dataset: bool = False, + ) -> BaseAutoDataset: """ - This function is used to generate an ``AutoDataset`` from a ``DataPipeline`` if provided + This function is used to generate a ``BaseAutoDataset`` from a ``DataPipeline`` if provided or from the provided ``whole_data_load_fn``, ``per_sample_load_fn`` functions directly """ @@ -309,7 +321,11 @@ def autogenerate_dataset( cls.preprocess_cls, DataPipeline._resolve_function_hierarchy('load_sample', cls.preprocess_cls, running_stage, Preprocess) ) - return AutoDataset(data, whole_data_load_fn, per_sample_load_fn, data_pipeline, running_stage=running_stage) + if use_iterable_auto_dataset: + return IterableAutoDataset( + data, whole_data_load_fn, per_sample_load_fn, data_pipeline, running_stage=running_stage + ) + return AutoDataset(data, whole_data_load_fn, per_sample_load_fn, data_pipeline, running_stage=running_stage) @staticmethod def train_val_test_split( @@ -379,15 +395,27 @@ def _generate_dataset_if_possible( running_stage: RunningStage, whole_data_load_fn: Optional[Callable] = None, per_sample_load_fn: Optional[Callable] = None, - data_pipeline: Optional[DataPipeline] = None - ) -> Optional[AutoDataset]: + data_pipeline: Optional[DataPipeline] = None, + use_iterable_auto_dataset: bool = False, + ) -> Optional[BaseAutoDataset]: if data is None: return if data_pipeline: - return data_pipeline._generate_auto_dataset(data, running_stage=running_stage) + return data_pipeline._generate_auto_dataset( + data, + running_stage=running_stage, + use_iterable_auto_dataset=use_iterable_auto_dataset, + ) - return cls.autogenerate_dataset(data, running_stage, whole_data_load_fn, per_sample_load_fn, data_pipeline) + return cls.autogenerate_dataset( + data, + running_stage, + whole_data_load_fn, + per_sample_load_fn, + data_pipeline, +
use_iterable_auto_dataset=use_iterable_auto_dataset, + ) @classmethod def from_load_data_inputs( @@ -398,6 +426,7 @@ def from_load_data_inputs( predict_load_data_input: Optional[Any] = None, preprocess: Optional[Preprocess] = None, postprocess: Optional[Postprocess] = None, + use_iterable_auto_dataset: bool = False, **kwargs, ) -> 'DataModule': """ @@ -429,16 +458,28 @@ def from_load_data_inputs( data_fetcher.attach_to_preprocess(data_pipeline._preprocess_pipeline) train_dataset = cls._generate_dataset_if_possible( - train_load_data_input, running_stage=RunningStage.TRAINING, data_pipeline=data_pipeline + train_load_data_input, + running_stage=RunningStage.TRAINING, + data_pipeline=data_pipeline, + use_iterable_auto_dataset=use_iterable_auto_dataset, ) val_dataset = cls._generate_dataset_if_possible( - val_load_data_input, running_stage=RunningStage.VALIDATING, data_pipeline=data_pipeline + val_load_data_input, + running_stage=RunningStage.VALIDATING, + data_pipeline=data_pipeline, + use_iterable_auto_dataset=use_iterable_auto_dataset, ) test_dataset = cls._generate_dataset_if_possible( - test_load_data_input, running_stage=RunningStage.TESTING, data_pipeline=data_pipeline + test_load_data_input, + running_stage=RunningStage.TESTING, + data_pipeline=data_pipeline, + use_iterable_auto_dataset=use_iterable_auto_dataset, ) predict_dataset = cls._generate_dataset_if_possible( - predict_load_data_input, running_stage=RunningStage.PREDICTING, data_pipeline=data_pipeline + predict_load_data_input, + running_stage=RunningStage.PREDICTING, + data_pipeline=data_pipeline, + use_iterable_auto_dataset=use_iterable_auto_dataset, ) datamodule = cls( train_dataset=train_dataset, diff --git a/flash/data/data_pipeline.py b/flash/data/data_pipeline.py index e0c9940eef..2125b79909 100644 --- a/flash/data/data_pipeline.py +++ b/flash/data/data_pipeline.py @@ -20,10 +20,10 @@ from pytorch_lightning.trainer.states import RunningStage from pytorch_lightning.utilities import imports from 
pytorch_lightning.utilities.exceptions import MisconfigurationException +from torch.utils.data import DataLoader, IterableDataset from torch.utils.data._utils.collate import default_collate, default_convert -from torch.utils.data.dataloader import DataLoader -from flash.data.auto_dataset import AutoDataset +from flash.data.auto_dataset import AutoDataset, IterableAutoDataset from flash.data.batch import _PostProcessor, _PreProcessor, _Sequential from flash.data.process import Postprocess, Preprocess from flash.data.utils import _POSTPROCESS_FUNCS, _PREPROCESS_FUNCS, _STAGES_PREFIX @@ -144,52 +144,61 @@ def _resolve_function_hierarchy( return function_name + def _make_collates(self, on_device: bool, collate: Callable) -> Tuple[Callable, Callable]: + if on_device: + return self._identity, collate + else: + return collate, self._identity + def _create_collate_preprocessors( self, stage: RunningStage, collate_fn: Optional[Callable] = None, ) -> Tuple[_PreProcessor, _PreProcessor]: + original_collate_fn = collate_fn + if collate_fn is None: collate_fn = default_collate preprocess: Preprocess = self._preprocess_pipeline + prefix: str = _STAGES_PREFIX[stage] func_names: Dict[str, str] = { k: self._resolve_function_hierarchy(k, preprocess, stage, Preprocess) for k in self.PREPROCESS_FUNCS } - if self._is_overriden_recursive("collate", preprocess, Preprocess, prefix=_STAGES_PREFIX[stage]): + if self._is_overriden_recursive("collate", preprocess, Preprocess, prefix=prefix): collate_fn: Callable = getattr(preprocess, func_names["collate"]) per_batch_transform_overriden: bool = self._is_overriden_recursive( - "per_batch_transform", preprocess, Preprocess, prefix=_STAGES_PREFIX[stage] + "per_batch_transform", preprocess, Preprocess, prefix=prefix ) per_sample_transform_on_device_overriden: bool = self._is_overriden_recursive( - "per_sample_transform_on_device", preprocess, Preprocess, prefix=_STAGES_PREFIX[stage] + "per_sample_transform_on_device", preprocess, Preprocess, 
prefix=prefix ) - skip_mutual_check: bool = getattr(preprocess, "skip_mutual_check", False) + collate_in_worker_from_transform: Optional[bool] = getattr( + preprocess, f"_{prefix}_collate_in_worker_from_transform" + ) - if (not skip_mutual_check and per_batch_transform_overriden and per_sample_transform_on_device_overriden): + if ( + collate_in_worker_from_transform is None and per_batch_transform_overriden + and per_sample_transform_on_device_overriden + ): raise MisconfigurationException( f'{self.__class__.__name__}: `per_batch_transform` and `per_sample_transform_on_device` ' f'are mutual exclusive for stage {stage}' ) - elif per_batch_transform_overriden: - worker_collate_fn = collate_fn - device_collate_fn = self._identity - - elif per_sample_transform_on_device_overriden: - worker_collate_fn = self._identity - device_collate_fn = collate_fn - + if isinstance(collate_in_worker_from_transform, bool): + worker_collate_fn, device_collate_fn = self._make_collates(not collate_in_worker_from_transform, collate_fn) else: - worker_collate_fn = collate_fn - device_collate_fn = self._identity + worker_collate_fn, device_collate_fn = self._make_collates( + per_sample_transform_on_device_overriden, collate_fn + ) worker_collate_fn = worker_collate_fn.collate_fn if isinstance( worker_collate_fn, _PreProcessor @@ -284,9 +293,6 @@ def _attach_preprocess_to_model( for stage in stages: - if stage == RunningStage.PREDICTING: - pass - loader_name = f'{_STAGES_PREFIX[stage]}_dataloader' dataloader, whole_attr_name = self._get_dataloader(model, loader_name) @@ -297,7 +303,7 @@ def _attach_preprocess_to_model( if isinstance(dataloader, (_PatchDataLoader, Callable)): dataloader = dataloader() - if not dataloader: + if dataloader is None: continue if isinstance(dataloader, Sequence): @@ -315,6 +321,9 @@ def _attach_preprocess_to_model( stage=stage, collate_fn=dl_args['collate_fn'] ) + if isinstance(dl_args["dataset"], IterableDataset): + del dl_args["sampler"] + # don't have to 
reinstantiate loader if just rewrapping devices (happens during detach) if not device_transform_only: del dl_args["batch_sampler"] @@ -427,8 +436,13 @@ def _detach_preprocessing_from_model(self, model: 'Task', stage: Optional[Runnin dl_args = {k: v for k, v in vars(loader).items() if not k.startswith("_")} if isinstance(dl_args['collate_fn'], _PreProcessor): - dl_args['collate_fn'] = dl_args['collate_fn']._original_collate_fn + dl_args["collate_fn"] = dl_args["collate_fn"]._original_collate_fn + + if isinstance(dl_args["dataset"], IterableAutoDataset): + del dl_args['sampler'] + del dl_args["batch_sampler"] + loader = type(loader)(**dl_args) dataloader[idx] = loader @@ -458,7 +472,14 @@ def fn(): return fn - def _generate_auto_dataset(self, data: Union[Iterable, Any], running_stage: RunningStage = None) -> AutoDataset: + def _generate_auto_dataset( + self, + data: Union[Iterable, Any], + running_stage: RunningStage = None, + use_iterable_auto_dataset: bool = False + ) -> Union[AutoDataset, IterableAutoDataset]: + if use_iterable_auto_dataset: + return IterableAutoDataset(data, data_pipeline=self, running_stage=running_stage) return AutoDataset(data=data, data_pipeline=self, running_stage=running_stage) def to_dataloader( diff --git a/flash/data/process.py b/flash/data/process.py index 670f906ed0..542ae8f3dc 100644 --- a/flash/data/process.py +++ b/flash/data/process.py @@ -17,13 +17,14 @@ import torch from pytorch_lightning.trainer.states import RunningStage +from pytorch_lightning.utilities.exceptions import MisconfigurationException from torch import Tensor from torch.nn import Module from torch.utils.data._utils.collate import default_collate from flash.data.batch import default_uncollate from flash.data.callback import FlashCallback -from flash.data.utils import convert_to_modules +from flash.data.utils import _PREPROCESS_FUNCS, _STAGES_PREFIX, convert_to_modules class Properties: @@ -100,7 +101,7 @@ class PreprocessState: pass -class Preprocess(Properties, 
torch.nn.Module): +class Preprocess(Properties, Module): """ The :class:`~flash.data.process.Preprocess` encapsulates all the data processing and loading logic that should run before the data is passed to the model. @@ -254,37 +255,88 @@ def load_data(cls, path_to_data: str) -> Iterable: def __init__( self, - train_transform: Optional[Union[Callable, Module, Dict[str, Callable]]] = None, - val_transform: Optional[Union[Callable, Module, Dict[str, Callable]]] = None, - test_transform: Optional[Union[Callable, Module, Dict[str, Callable]]] = None, - predict_transform: Optional[Union[Callable, Module, Dict[str, Callable]]] = None, + train_transform: Optional[Dict[str, Callable]] = None, + val_transform: Optional[Dict[str, Callable]] = None, + test_transform: Optional[Dict[str, Callable]] = None, + predict_transform: Optional[Dict[str, Callable]] = None, ): super().__init__() - self.train_transform = convert_to_modules(train_transform) - self.val_transform = convert_to_modules(val_transform) - self.test_transform = convert_to_modules(test_transform) - self.predict_transform = convert_to_modules(predict_transform) + + # used to keep track of provided transforms + self._train_collate_in_worker_from_transform: Optional[bool] = None + self._val_collate_in_worker_from_transform: Optional[bool] = None + self._predict_collate_in_worker_from_transform: Optional[bool] = None + self._test_collate_in_worker_from_transform: Optional[bool] = None + + self.train_transform = convert_to_modules(self._check_transforms(train_transform, RunningStage.TRAINING)) + self.val_transform = convert_to_modules(self._check_transforms(val_transform, RunningStage.VALIDATING)) + self.test_transform = convert_to_modules(self._check_transforms(test_transform, RunningStage.TESTING)) + self.predict_transform = convert_to_modules(self._check_transforms(predict_transform, RunningStage.PREDICTING)) if not hasattr(self, "_skip_mutual_check"): self._skip_mutual_check = False self._callbacks: 
List[FlashCallback] = [] - @property - def skip_mutual_check(self) -> bool: - return self._skip_mutual_check + # todo (tchaton) Add a warning if a transform is provided, but the hook hasn't been overriden ! + def _check_transforms(self, transform: Optional[Dict[str, Callable]], + stage: RunningStage) -> Optional[Dict[str, Callable]]: + if transform is None: + return transform + + if not isinstance(transform, Dict): + raise MisconfigurationException( + "Transform should be a dict. " + f"Here are the available keys for your transforms: {_PREPROCESS_FUNCS}." + ) + + keys_diff = set(transform.keys()).difference(_PREPROCESS_FUNCS) + + if len(keys_diff) > 0: + raise MisconfigurationException( + f"{stage}_transform contains {keys_diff}. Only {_PREPROCESS_FUNCS} keys are supported." + ) + + is_per_batch_transform_in = "per_batch_transform" in transform + is_per_sample_transform_on_device_in = "per_sample_transform_on_device" in transform - @skip_mutual_check.setter - def skip_mutual_check(self, skip_mutual_check: bool) -> None: - self._skip_mutual_check = skip_mutual_check + if is_per_batch_transform_in and is_per_sample_transform_on_device_in: + raise MisconfigurationException( + f'{transform}: `per_batch_transform` and `per_sample_transform_on_device` ' + f'are mutually exclusive.' + ) - def _identify(self, x: Any) -> Any: + collate_in_worker: Optional[bool] = None + + if is_per_batch_transform_in or (not is_per_batch_transform_in and not is_per_sample_transform_on_device_in): + collate_in_worker = True + + elif is_per_sample_transform_on_device_in: + collate_in_worker = False + + setattr(self, f"_{_STAGES_PREFIX[stage]}_collate_in_worker_from_transform", collate_in_worker) + return transform + + @staticmethod + def _identity(x: Any) -> Any: return x + # todo (tchaton): Remove when merged. 
https://github.com/PyTorchLightning/pytorch-lightning/pull/7056 + def tmp_wrap(self, transform) -> Callable: + if "on_device" in self.current_fn: + + def fn(batch: Any): + if isinstance(batch, list) and len(batch) == 1 and isinstance(batch[0], dict): + return [transform(batch[0])] + return transform(batch) + + return fn + return transform + def _get_transform(self, transform: Dict[str, Callable]) -> Callable: if self.current_fn in transform: - return transform[self.current_fn] - return self._identify + return self.tmp_wrap(transform[self.current_fn]) + return self._identity @property def current_transform(self) -> Callable: @@ -297,7 +349,7 @@ def current_transform(self) -> Callable: elif self.predicting and self.predict_transform: return self._get_transform(self.predict_transform) else: - return self._identify + return self._identity @classmethod def from_state(cls, state: PreprocessState) -> 'Preprocess': @@ -388,7 +440,7 @@ def per_batch_transform_on_device(self, batch: Any) -> Any: return batch -class Postprocess(Properties, torch.nn.Module): +class Postprocess(Properties, Module): def __init__(self, save_path: Optional[str] = None): super().__init__() diff --git a/flash/data/utils.py b/flash/data/utils.py index f3c28612c3..48bac51a93 100644 --- a/flash/data/utils.py +++ b/flash/data/utils.py @@ -14,13 +14,14 @@ import os.path import zipfile -from typing import Any, Callable, Dict, Iterable, Mapping, Set, Type +from typing import Any, Callable, Dict, Iterable, Mapping, Optional, Set, Type import requests import torch from pytorch_lightning.trainer.states import RunningStage from pytorch_lightning.utilities.apply_func import apply_to_collection from torch import Tensor +from torch.nn import Module from tqdm.auto import tqdm as tq _STAGES_PREFIX = { @@ -119,10 +120,13 @@ def download_data(url: str, path: str = "data/", verbose: bool = False) -> None: Usage: download_file('http://web4host.net/5MB.zip') """ + if url == "NEED_TO_BE_CREATED": + raise 
NotImplementedError + if not os.path.exists(path): os.makedirs(path) local_filename = os.path.join(path, url.split('/')[-1]) - r = requests.get(url, stream=True) + r = requests.get(url, stream=True) file_size = int(r.headers['Content-Length']) if 'Content-Length' in r.headers else 0 chunk_size = 1024 num_bars = int(file_size / chunk_size) @@ -177,7 +181,7 @@ def __str__(self) -> str: return f"{self.__class__.__name__}({str(self.func)})" -def convert_to_modules(transforms: Dict): +def convert_to_modules(transforms: Optional[Dict[str, Callable]]): if transforms is None or isinstance(transforms, torch.nn.Module): return transforms diff --git a/flash/vision/classification/data.py b/flash/vision/classification/data.py index 5627c1b620..158ee52b29 100644 --- a/flash/vision/classification/data.py +++ b/flash/vision/classification/data.py @@ -39,9 +39,6 @@ class ImageClassificationPreprocess(Preprocess): - # this assignement is used to skip the assert that `per_batch_transform` and `per_sample_transform_on_device` - # are mutually exclusive on the DataPipeline internals - _skip_mutual_check = True to_tensor = torchvision.transforms.ToTensor() @staticmethod @@ -160,7 +157,7 @@ def pre_tensor_transform(self, sample: Any) -> Any: return self.common_step(sample) def to_tensor_transform(self, sample: Any) -> Any: - if self.current_transform == self._identify: + if self.current_transform == self._identity: if isinstance(sample, (list, tuple)): source, target = sample if isinstance(source, torch.Tensor): @@ -176,9 +173,6 @@ def to_tensor_transform(self, sample: Any) -> Any: def post_tensor_transform(self, sample: Any) -> Any: return self.common_step(sample) - # todo: (tchaton) `per_batch_transform` and `per_sample_transform_on_device` are mutually exclusive - # `skip_mutual_check` is used to skip the checks as the information are provided from the transforms directly - # Need to properly set the `collate` depending on user provided transforms def
per_batch_transform(self, sample: Any) -> Any: return self.common_step(sample) @@ -247,7 +241,7 @@ def _check_transforms(transform: Dict[str, Union[nn.Module, Callable]]) -> Dict[ if "per_batch_transform" in transform and "per_sample_transform_on_device" in transform: raise MisconfigurationException( f'{transform}: `per_batch_transform` and `per_sample_transform_on_device` ' - f'are mutual exclusive.' + f'are mutually exclusive.' ) return transform @@ -313,7 +307,7 @@ def instantiate_preprocess( val_transform: Dict[str, Union[nn.Module, Callable]], test_transform: Dict[str, Union[nn.Module, Callable]], predict_transform: Dict[str, Union[nn.Module, Callable]], - preprocess_cls: Type[Preprocess] = None + preprocess_cls: Type[Preprocess] = None, ) -> Preprocess: """ This function is used to instantiate ImageClassificationData preprocess object. diff --git a/flash/vision/detection/data.py b/flash/vision/detection/data.py index b605215841..d08ac6cdef 100644 --- a/flash/vision/detection/data.py +++ b/flash/vision/detection/data.py @@ -12,13 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import os -from typing import Any, Callable, List, Optional, Tuple, Type +from typing import Any, Callable, Dict, List, Optional, Tuple, Type import torch from PIL import Image from pytorch_lightning.utilities.exceptions import MisconfigurationException from torch import Tensor, tensor from torch._six import container_abcs +from torch.nn import Module from torch.utils.data._utils.collate import default_collate from torchvision import transforms as T @@ -130,9 +131,6 @@ def _has_valid_annotation(annot: List): return dataset -_default_transform = T.ToTensor() - - class ObjectDetectionPreprocess(Preprocess): to_tensor = T.ToTensor() @@ -163,6 +161,9 @@ def pre_tensor_transform(self, samples: Any) -> Any: return outputs raise MisconfigurationException("The samples should either be a tensor, a list of paths or a path.") + def to_tensor_transform(self, sample) -> Any: + return self.to_tensor(sample[0]), sample[1] + def predict_to_tensor_transform(self, sample) -> Any: return self.to_tensor(sample[0]) @@ -182,35 +183,40 @@ class ObjectDetectionData(DataModule): @classmethod def instantiate_preprocess( cls, - train_transform: Optional[Callable], - val_transform: Optional[Callable], - preprocess_cls: Type[Preprocess] = None + train_transform: Optional[Dict[str, Module]] = None, + val_transform: Optional[Dict[str, Module]] = None, + test_transform: Optional[Dict[str, Module]] = None, + predict_transform: Optional[Dict[str, Module]] = None, + preprocess_cls: Type[Preprocess] = None, ) -> Preprocess: preprocess_cls = preprocess_cls or cls.preprocess_cls - return preprocess_cls(train_transform, val_transform) + return preprocess_cls(train_transform, val_transform, test_transform, predict_transform) @classmethod def from_coco( cls, train_folder: Optional[str] = None, train_ann_file: Optional[str] = None, - train_transform: Optional[Callable] = _default_transform, + train_transform: Optional[Dict[str, Module]] = None, val_folder: Optional[str] = None, val_ann_file: Optional[str] 
= None, - val_transform: Optional[Callable] = _default_transform, + val_transform: Optional[Dict[str, Module]] = None, test_folder: Optional[str] = None, test_ann_file: Optional[str] = None, - test_transform: Optional[Callable] = _default_transform, + test_transform: Optional[Dict[str, Module]] = None, + predict_transform: Optional[Dict[str, Module]] = None, batch_size: int = 4, num_workers: Optional[int] = None, preprocess_cls: Type[Preprocess] = None, **kwargs ): - preprocess = cls.instantiate_preprocess(train_transform, val_transform, preprocess_cls=preprocess_cls) + preprocess = cls.instantiate_preprocess( + train_transform, val_transform, test_transform, predict_transform, preprocess_cls=preprocess_cls + ) - datamodule = cls.from_load_data_inputs( + return cls.from_load_data_inputs( train_load_data_input=(train_folder, train_ann_file, train_transform), val_load_data_input=(val_folder, val_ann_file, val_transform) if val_folder else None, test_load_data_input=(test_folder, test_ann_file, test_transform) if test_folder else None, @@ -219,5 +225,3 @@ def from_coco( preprocess=preprocess, **kwargs ) - datamodule.num_classes = datamodule._train_ds.num_classes - return datamodule diff --git a/flash_examples/finetuning/image_classification.py b/flash_examples/finetuning/image_classification.py index 6c0aa1ed3e..de19d76d05 100644 --- a/flash_examples/finetuning/image_classification.py +++ b/flash_examples/finetuning/image_classification.py @@ -33,7 +33,7 @@ # 3.a Optional: Register a custom backbone # This is useful to create new backbone and make them accessible from `ImageClassifier` -@ImageClassifier.backbones(name="username/resnet18") +@ImageClassifier.backbones(name="resnet18") def fn_resnet(pretrained: bool = True): model = torchvision.models.resnet18(pretrained) # remove the last two layers & turn it into a Sequential model @@ -47,7 +47,7 @@ def fn_resnet(pretrained: bool = True): print(ImageClassifier.available_backbones()) # 4.
Build the model -model = ImageClassifier(backbone="username/resnet18", num_classes=datamodule.num_classes) +model = ImageClassifier(backbone="resnet18", num_classes=datamodule.num_classes) # 5. Create the trainer. trainer = flash.Trainer(max_epochs=1, limit_train_batches=1, limit_val_batches=1) diff --git a/flash_examples/generic_task.py b/flash_examples/generic_task.py index ec92fcb90e..e82d6d0ee2 100644 --- a/flash_examples/generic_task.py +++ b/flash_examples/generic_task.py @@ -18,7 +18,7 @@ from torch.utils.data import DataLoader, random_split from torchvision import datasets, transforms -from flash import ClassificationTask +from flash.core.classification import ClassificationTask from flash.data.utils import download_data _PATH_ROOT = os.path.dirname(os.path.dirname(__file__)) diff --git a/tests/data/test_data_pipeline.py b/tests/data/test_data_pipeline.py index d1682c053c..9d15247cf2 100644 --- a/tests/data/test_data_pipeline.py +++ b/tests/data/test_data_pipeline.py @@ -27,7 +27,7 @@ from torch.utils.data._utils.collate import default_collate from flash.core import Task -from flash.data.auto_dataset import AutoDataset +from flash.data.auto_dataset import AutoDataset, IterableAutoDataset from flash.data.batch import _PostProcessor, _PreProcessor from flash.data.data_module import DataModule from flash.data.data_pipeline import _StageOrchestrator, DataPipeline @@ -213,7 +213,7 @@ def test_per_batch_transform_on_device(self, *_, **__): assert _seq.pre_tensor_transform.func == preprocess.val_pre_tensor_transform assert _seq.to_tensor_transform.func == preprocess.to_tensor_transform assert _seq.post_tensor_transform.func == preprocess.post_tensor_transform - assert val_worker_preprocessor.collate_fn.func == data_pipeline._identity + assert val_worker_preprocessor.collate_fn.func == DataPipeline._identity assert val_worker_preprocessor.per_batch_transform.func == preprocess.per_batch_transform _seq = test_worker_preprocessor.per_sample_transform @@ -594,7 +594,8 
@@ def val_per_batch_transform_on_device(self, batch: Any) -> Any: assert self.validating assert self.current_fn == "per_batch_transform_on_device" self.val_per_batch_transform_on_device_called = True - batch = batch[0] + if isinstance(batch, list): + batch = batch[0] assert torch.equal(batch["a"], tensor([0, 1])) assert torch.equal(batch["b"], tensor([1, 2])) return [False] @@ -648,6 +649,8 @@ def training_step(self, batch, batch_idx): assert batch is None def validation_step(self, batch, batch_idx): + if isinstance(batch, list): + batch = batch[0] assert batch is False def test_step(self, batch, batch_idx): @@ -824,3 +827,96 @@ def from_folders( ) trainer.fit(model, datamodule=datamodule) trainer.test(model) + + +def test_preprocess_transforms(tmpdir): + """ + This test makes sure that when a preprocess is being provided transforms as dictionaries, + checking is done properly, and collate_in_worker_from_transform is properly extracted. + """ + + with pytest.raises(MisconfigurationException, match="Transform should be a dict."): + Preprocess(train_transform="choco") + + with pytest.raises(MisconfigurationException, match="train_transform contains {'choco'}. 
Only"): + Preprocess(train_transform={"choco": None}) + + preprocess = Preprocess(train_transform={"to_tensor_transform": torch.nn.Linear(1, 1)}) + # keep is None + assert preprocess._train_collate_in_worker_from_transform is True + assert preprocess._val_collate_in_worker_from_transform is None + assert preprocess._test_collate_in_worker_from_transform is None + assert preprocess._predict_collate_in_worker_from_transform is None + + with pytest.raises(MisconfigurationException, match="`per_batch_transform` and `per_sample_transform_on_device`"): + preprocess = Preprocess( + train_transform={ + "per_batch_transform": torch.nn.Linear(1, 1), + "per_sample_transform_on_device": torch.nn.Linear(1, 1) + } + ) + + preprocess = Preprocess( + train_transform={"per_batch_transform": torch.nn.Linear(1, 1)}, + predict_transform={"per_sample_transform_on_device": torch.nn.Linear(1, 1)} + ) + # keep is None + assert preprocess._train_collate_in_worker_from_transform is True + assert preprocess._val_collate_in_worker_from_transform is None + assert preprocess._test_collate_in_worker_from_transform is None + assert preprocess._predict_collate_in_worker_from_transform is False + + train_preprocessor = DataPipeline(preprocess).worker_preprocessor(RunningStage.TRAINING) + val_preprocessor = DataPipeline(preprocess).worker_preprocessor(RunningStage.VALIDATING) + test_preprocessor = DataPipeline(preprocess).worker_preprocessor(RunningStage.TESTING) + predict_preprocessor = DataPipeline(preprocess).worker_preprocessor(RunningStage.PREDICTING) + + assert train_preprocessor.collate_fn.func == default_collate + assert val_preprocessor.collate_fn.func == default_collate + assert test_preprocessor.collate_fn.func == default_collate + assert predict_preprocessor.collate_fn.func == DataPipeline._identity + + class CustomPreprocess(Preprocess): + + def per_sample_transform_on_device(self, sample: Any) -> Any: + return super().per_sample_transform_on_device(sample) + + def 
per_batch_transform(self, batch: Any) -> Any: + return super().per_batch_transform(batch) + + preprocess = CustomPreprocess( + train_transform={"per_batch_transform": torch.nn.Linear(1, 1)}, + predict_transform={"per_sample_transform_on_device": torch.nn.Linear(1, 1)} + ) + # keep is None + assert preprocess._train_collate_in_worker_from_transform is True + assert preprocess._val_collate_in_worker_from_transform is None + assert preprocess._test_collate_in_worker_from_transform is None + assert preprocess._predict_collate_in_worker_from_transform is False + + data_pipeline = DataPipeline(preprocess) + + train_preprocessor = data_pipeline.worker_preprocessor(RunningStage.TRAINING) + with pytest.raises(MisconfigurationException, match="`per_batch_transform` and `per_sample_transform_on_device`"): + val_preprocessor = data_pipeline.worker_preprocessor(RunningStage.VALIDATING) + with pytest.raises(MisconfigurationException, match="`per_batch_transform` and `per_sample_transform_on_device`"): + test_preprocessor = data_pipeline.worker_preprocessor(RunningStage.TESTING) + predict_preprocessor = data_pipeline.worker_preprocessor(RunningStage.PREDICTING) + + assert train_preprocessor.collate_fn.func == default_collate + assert predict_preprocessor.collate_fn.func == DataPipeline._identity + + +def test_iterable_auto_dataset(tmpdir): + + class CustomPreprocess(Preprocess): + + def load_sample(self, index: int) -> Dict[str, int]: + return {"index": index} + + data_pipeline = DataPipeline(CustomPreprocess()) + + ds = IterableAutoDataset(range(10), running_stage=RunningStage.TRAINING, data_pipeline=data_pipeline) + + for index, v in enumerate(ds): + assert v == {"index": index} diff --git a/tests/vision/classification/test_data_model_integration.py b/tests/vision/classification/test_data_model_integration.py index 1181df70ee..fb00d93b0f 100644 --- a/tests/vision/classification/test_data_model_integration.py +++ b/tests/vision/classification/test_data_model_integration.py @@ 
-42,7 +42,7 @@ def test_classification(tmpdir): data = ImageClassificationData.from_filepaths( train_filepaths=[tmpdir / "a", tmpdir / "b"], train_labels=[0, 1], - train_transform={"per_sample_per_batch_transform": lambda x: x}, + train_transform={"per_batch_transform": lambda x: x}, num_workers=0, batch_size=2, )