diff --git a/docs/source/api/data.rst b/docs/source/api/data.rst
index 77a38e560f..f862045337 100644
--- a/docs/source/api/data.rst
+++ b/docs/source/api/data.rst
@@ -90,12 +90,10 @@ _______________________
     :nosignatures:
     :template: classtemplate.rst
 
-    ~flash.core.data.io.input_transform.BaseInputTransform
-    ~flash.core.data.io.input_transform.DefaultInputTransform
     ~flash.core.data.process.DeserializerMapping
     ~flash.core.data.process.Deserializer
     ~flash.core.data.io.output_transform.OutputTransform
-    ~flash.core.data.io.input_transform.InputTransform
+    ~flash.core.data.input_transform.InputTransform
 
 flash.core.data.properties
 __________________________
@@ -144,9 +142,6 @@ _____________________
     :nosignatures:
     :template: classtemplate.rst
 
-    ~flash.core.data.utils.CurrentFuncContext
-    ~flash.core.data.utils.CurrentRunningStageContext
-    ~flash.core.data.utils.CurrentRunningStageFuncContext
     ~flash.core.data.utils.FuncModule
 
 .. autosummary::
diff --git a/docs/source/api/flash.rst b/docs/source/api/flash.rst
index cc8775572e..dd18597221 100644
--- a/docs/source/api/flash.rst
+++ b/docs/source/api/flash.rst
@@ -12,6 +12,6 @@ flash
     ~flash.core.data.callback.FlashCallback
     ~flash.core.data.io.output_transform.OutputTransform
     ~flash.core.data.io.output.Output
-    ~flash.core.data.io.input_transform.InputTransform
+    ~flash.core.data.input_transform.InputTransform
     ~flash.core.model.Task
     ~flash.core.trainer.Trainer
diff --git a/flash/__init__.py b/flash/__init__.py
index 7073f68180..d2f29d5920 100644
--- a/flash/__init__.py
+++ b/flash/__init__.py
@@ -21,8 +21,8 @@
 
     from flash.core.data.callback import FlashCallback
     from flash.core.data.data_module import DataModule
+    from flash.core.data.input_transform import InputTransform
     from flash.core.data.io.input import DataKeys, Input
-    from flash.core.data.io.input_transform import InputTransform
     from flash.core.data.io.output import Output
     from flash.core.data.io.output_transform import OutputTransform
     from flash.core.data.process import Serializer
diff --git a/flash/audio/classification/data.py b/flash/audio/classification/data.py
index acbd1c41a3..bc24f5d717 100644
--- a/flash/audio/classification/data.py
+++ b/flash/audio/classification/data.py
@@ -27,10 +27,10 @@
 )
 from flash.audio.classification.input_transform import AudioClassificationInputTransform
 from flash.core.data.callback import BaseDataFetcher
+from flash.core.data.data_module import DataModule
 from flash.core.data.data_pipeline import DataPipelineState
 from flash.core.data.input_transform import INPUT_TRANSFORM_TYPE
 from flash.core.data.io.input import Input
-from flash.core.data.new_data_module import DataModule
 from flash.core.data.utilities.paths import PATH_TYPE
 from flash.core.registry import FlashRegistry
 from flash.core.utilities.stages import RunningStage
diff --git a/flash/audio/speech_recognition/data.py b/flash/audio/speech_recognition/data.py
index de9d673a5e..2e5de11b4d 100644
--- a/flash/audio/speech_recognition/data.py
+++ b/flash/audio/speech_recognition/data.py
@@ -22,10 +22,10 @@
     SpeechRecognitionPathsInput,
 )
 from flash.audio.speech_recognition.output_transform import SpeechRecognitionOutputTransform
+from flash.core.data.data_module import DataModule
 from flash.core.data.data_pipeline import DataPipelineState
 from flash.core.data.input_transform import INPUT_TRANSFORM_TYPE, InputTransform
 from flash.core.data.io.input import Input
-from flash.core.data.new_data_module import DataModule
 from flash.core.registry import FlashRegistry
 from flash.core.utilities.stages import RunningStage
diff --git a/flash/core/data/batch.py b/flash/core/data/batch.py
index b8302f3185..fb3d196d35 100644
--- a/flash/core/data/batch.py
+++ b/flash/core/data/batch.py
@@ -17,43 +17,14 @@
 from torch import Tensor
 
 from flash.core.data.callback import ControlFlow, FlashCallback
-from flash.core.data.utils import convert_to_modules, CurrentFuncContext, CurrentRunningStageContext
+from flash.core.data.utils import convert_to_modules
 from flash.core.utilities.stages import RunningStage
 
 if TYPE_CHECKING:
-    from flash.core.data.io.input_transform import InputTransform
+    from flash.core.data.input_transform import InputTransform
     from flash.core.data.process import Deserializer
 
 
-class _DeserializeProcessor(torch.nn.Module):
-    def __init__(
-        self,
-        deserializer: "Deserializer",
-        input_transform: "InputTransform",
-        per_sample_transform: Callable,
-        callbacks: Optional[List[FlashCallback]] = None,
-    ):
-        super().__init__()
-        self.input_transform = input_transform
-        self.callback = ControlFlow(callbacks or [])
-        self.deserializer = convert_to_modules(deserializer)
-        self.per_sample_transform = convert_to_modules(per_sample_transform)
-
-        self._current_stage_context = CurrentRunningStageContext(RunningStage.PREDICTING, input_transform, reset=False)
-        self._per_sample_transform_context = CurrentFuncContext("per_sample_transform", input_transform)
-
-    def forward(self, sample: str):
-
-        sample = self.deserializer(sample)
-
-        with self._current_stage_context:
-            with self._per_sample_transform_context:
-                sample = self.per_sample_transform(sample)
-                self.callback.on_per_sample_transform(sample, RunningStage.PREDICTING)
-
-        return sample
-
-
 class _DeserializeProcessorV2(torch.nn.Module):
     def __init__(
         self,
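The changes so far make two canonical import paths official: ``DataModule`` moves out of ``flash.core.data.new_data_module`` into ``flash.core.data.data_module``, and the new-style ``InputTransform`` moves out of ``flash.core.data.io.input_transform`` into ``flash.core.data.input_transform``. A quick sanity check of the paths this patch establishes (a sketch of downstream user code, not part of the diff itself)::

    # the promoted DataModule (previously flash.core.data.new_data_module)
    from flash.core.data.data_module import DataModule

    # the new-style InputTransform (previously flash.core.data.io.input_transform)
    from flash.core.data.input_transform import InputTransform

    # both are also re-exported at the top level via flash/__init__.py above
    from flash import DataModule, InputTransform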
diff --git a/flash/core/data/data_module.py b/flash/core/data/data_module.py
index e247d60687..8249bbae05 100644
--- a/flash/core/data/data_module.py
+++ b/flash/core/data/data_module.py
@@ -12,25 +12,27 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
-from typing import Any, Callable, Iterable, List, Optional, Tuple, Type, TYPE_CHECKING, Union
+from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type, TYPE_CHECKING, Union
 
 import numpy as np
 import pytorch_lightning as pl
 import torch
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from torch.utils.data import DataLoader, Dataset
-from torch.utils.data.dataset import IterableDataset, Subset
+from torch.utils.data._utils.collate import default_collate
+from torch.utils.data.dataset import IterableDataset
 from torch.utils.data.sampler import Sampler
 
 import flash
 from flash.core.data.base_viz import BaseVisualization
 from flash.core.data.callback import BaseDataFetcher
-from flash.core.data.data_pipeline import DataPipeline
-from flash.core.data.io.input import Input, InputBase, IterableInput
-from flash.core.data.io.input_transform import DefaultInputTransform, InputTransform
+from flash.core.data.data_pipeline import DataPipeline, DataPipelineState
+from flash.core.data.input_transform import InputTransform
+from flash.core.data.io.input import DataKeys, Input, InputBase, IterableInput
 from flash.core.data.io.output_transform import OutputTransform
 from flash.core.data.splits import SplitDataset
 from flash.core.data.utils import _STAGES_PREFIX
+from flash.core.registry import FlashRegistry
 from flash.core.utilities.imports import _FIFTYONE_AVAILABLE
 from flash.core.utilities.stages import RunningStage
@@ -40,24 +42,26 @@
     SampleCollection = None
 
 
+class DatasetInput(Input):
+    """The ``DatasetInput`` implements default behaviours for data sources which expect the input to
+    :meth:`~flash.core.data.io.input.Input.load_data` to be a :class:`torch.utils.data.dataset.Dataset`.
+    """
+
+    def load_sample(self, sample: Any) -> Dict[str, Any]:
+        if isinstance(sample, tuple) and len(sample) == 2:
+            return {DataKeys.INPUT: sample[0], DataKeys.TARGET: sample[1]}
+        return {DataKeys.INPUT: sample}
+
+
 class DataModule(pl.LightningDataModule):
     """A basic DataModule class for all Flash tasks. This class includes references to a
-    :class:`~flash.core.data.io.input.Input`, :class:`~flash.core.data.io.input_transform.InputTransform`,
-    :class:`~flash.core.data.io.output_transform.OutputTransform`, and a
-    :class:`~flash.core.data.callback.BaseDataFetcher`.
+    :class:`~flash.core.data.io.input.Input` and a :class:`~flash.core.data.callback.BaseDataFetcher`.
 
     Args:
-        train_dataset: Dataset for training. Defaults to None.
-        val_dataset: Dataset for validating model performance during training. Defaults to None.
-        test_dataset: Dataset to test model performance. Defaults to None.
-        predict_dataset: Dataset for predicting. Defaults to None.
-        input: The :class:`~flash.core.data.io.input.Input` that was used to create the datasets.
-        input_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` to use when constructing the
-            :class:`~flash.core.data.data_pipeline.DataPipeline`. If ``None``, a
-            :class:`~flash.core.data.io.input_transform.DefaultInputTransform` will be used.
-        output_transform: The :class:`~flash.core.data.io.output_transform.OutputTransform` to use when constructing the
-            :class:`~flash.core.data.data_pipeline.DataPipeline`. If ``None``, a plain
-            :class:`~flash.core.data.io.output_transform.OutputTransform` will be used.
+        train_input: Input dataset for training. Defaults to None.
+        val_input: Input dataset for validating model performance during training. Defaults to None.
+        test_input: Input dataset to test model performance. Defaults to None.
+        predict_input: Input dataset for predicting. Defaults to None.
         data_fetcher: The :class:`~flash.core.data.callback.BaseDataFetcher` to attach to the
             :class:`~flash.core.data.io.input_transform.InputTransform`. If ``None``, the output from
             :meth:`~flash.core.data.data_module.DataModule.configure_data_fetcher` will be used.
@@ -71,57 +75,71 @@ class DataModule(pl.LightningDataModule):
         Will be passed to the DataLoader for the training dataset. Defaults to None.
     """
 
-    input_transform_cls = DefaultInputTransform
+    input_transform_cls = InputTransform
     output_transform_cls = OutputTransform
+    input_transforms_registry: Optional[FlashRegistry] = None
 
     def __init__(
         self,
-        train_dataset: Optional[Dataset] = None,
-        val_dataset: Optional[Dataset] = None,
-        test_dataset: Optional[Dataset] = None,
-        predict_dataset: Optional[Dataset] = None,
-        input: Optional[Input] = None,
-        input_transform: Optional[InputTransform] = None,
-        output_transform: Optional[OutputTransform] = None,
+        train_input: Optional[Input] = None,
+        val_input: Optional[Input] = None,
+        test_input: Optional[Input] = None,
+        predict_input: Optional[Input] = None,
         data_fetcher: Optional[BaseDataFetcher] = None,
         val_split: Optional[float] = None,
-        batch_size: int = 4,
+        batch_size: Optional[int] = None,
         num_workers: int = 0,
         sampler: Optional[Type[Sampler]] = None,
+        pin_memory: bool = True,
+        persistent_workers: bool = True,
+        output_transform: Optional[OutputTransform] = None,
    ) -> None:
 
-        super().__init__()
+        if not batch_size:
+            raise MisconfigurationException("The `batch_size` should be provided to the DataModule on instantiation.")
 
         if flash._IS_TESTING and torch.cuda.is_available():
             batch_size = 16
 
-        self._train_ds = train_dataset
-        self._val_ds = val_dataset
-        self._test_ds = test_dataset
-        self._predict_ds = predict_dataset
-
-        if self._train_ds and (val_split is not None and not self._val_ds):
-            self._train_ds, self._val_ds = self._split_train_val(self._train_ds, val_split)
-
-        self._input: Input = input
-        self._input_transform: Optional[InputTransform] = input_transform
+        self._input_transform: Optional[InputTransform] = None
         self._output_transform: Optional[OutputTransform] = output_transform
         self._viz: Optional[BaseVisualization] = None
+
+        self._train_input = train_input
+        self._val_input = val_input
+        self._test_input = test_input
+        self._predict_input = predict_input
+
         self._data_fetcher: Optional[BaseDataFetcher] = data_fetcher or self.configure_data_fetcher()
 
-        # TODO: InputTransform can change
-        self.data_fetcher.attach_to_input_transform(self.input_transform)
+        self._train_dataloader_collate_fn = self._resolve_dataloader_collate_fn(self._train_input)
+        self._val_dataloader_collate_fn = self._resolve_dataloader_collate_fn(self._val_input)
+        self._test_dataloader_collate_fn = self._resolve_dataloader_collate_fn(self._test_input)
+        self._predict_dataloader_collate_fn = self._resolve_dataloader_collate_fn(self._predict_input)
 
-        if self._train_ds:
+        self._train_on_after_batch_transfer_fn = self._resolve_on_after_batch_transfer_fn(self._train_input)
+        self._val_on_after_batch_transfer_fn = self._resolve_on_after_batch_transfer_fn(self._val_input)
+        self._test_on_after_batch_transfer_fn = self._resolve_on_after_batch_transfer_fn(self._test_input)
+        self._predict_on_after_batch_transfer_fn = self._resolve_on_after_batch_transfer_fn(self._predict_input)
+
+        if self._train_input and self._val_input and isinstance(val_split, float) and val_split > 0:
+            raise MisconfigurationException(
+                "A `val_input` was provided with `val_split`. Please choose one or the other."
+            )
+
+        if self._train_input is not None and (val_split is not None and self._val_input is None):
+            self._train_input, self._val_input = self._split_train_val(self._train_input, val_split)
+
+        if self._train_input:
             self.train_dataloader = self._train_dataloader
 
-        if self._val_ds:
+        if self._val_input:
             self.val_dataloader = self._val_dataloader
 
-        if self._test_ds:
+        if self._test_input:
             self.test_dataloader = self._test_dataloader
 
-        if self._predict_ds:
+        if self._predict_input:
             self.predict_dataloader = self._predict_dataloader
 
         self.batch_size = batch_size
@@ -129,175 +147,75 @@ def __init__(
         if num_workers is None:
             num_workers = 0
         self.num_workers = num_workers
+        self.persistent_workers = persistent_workers and num_workers > 0
+        self.pin_memory = pin_memory
 
         self.sampler = sampler
 
-        self.set_running_stages()
-
-        # Share state between input objects (this will be available in ``load_sample`` but not in ``load_data``)
-        data_pipeline = self.data_pipeline
-        data_pipeline.initialize()
+        super().__init__(self)
 
     @property
-    def train_dataset(self) -> Optional[Dataset]:
+    def train_dataset(self) -> Optional[Input]:
        """This property returns the train dataset."""
-        return self._train_ds
+        return self._train_input
- """ - return BaseDataFetcher() - - @property - def data_fetcher(self) -> BaseDataFetcher: - """This property returns the data fetcher.""" - return self._data_fetcher or DataModule.configure_data_fetcher() - - @data_fetcher.setter - def data_fetcher(self, data_fetcher: BaseDataFetcher) -> None: - self._data_fetcher = data_fetcher - - def _reset_iterator(self, stage: str) -> Iterable[Any]: - iter_name = f"_{stage}_iter" - # num_workers has to be set to 0 to work properly - num_workers = self.num_workers - self.num_workers = 0 - dataloader_fn = getattr(self, f"{stage}_dataloader") - iterator = iter(dataloader_fn()) - self.num_workers = num_workers - setattr(self, iter_name, iterator) - return iterator - - def _show_batch(self, stage: str, func_names: Union[str, List[str]], reset: bool = True) -> None: - """This function is used to handle transforms profiling for batch visualization.""" - # don't show in CI - if os.getenv("FLASH_TESTING", "0") == "1": + def _resolve_transform(self, ds: Optional[Input]) -> Optional[InputTransform]: + if not isinstance(ds, Input): return None - iter_name = f"_{stage}_iter" - - if not hasattr(self, iter_name): - self._reset_iterator(stage) - - # list of functions to visualise - if isinstance(func_names, str): - func_names = [func_names] - - iter_dataloader = getattr(self, iter_name) - with self.data_fetcher.enable(): - if reset: - self.data_fetcher.batches[stage] = {} - try: - _ = next(iter_dataloader) - except StopIteration: - iter_dataloader = self._reset_iterator(stage) - _ = next(iter_dataloader) - data_fetcher: BaseVisualization = self.data_fetcher - data_fetcher._show(stage, func_names) - if reset: - self.data_fetcher.batches[stage] = {} - - def show_train_batch(self, hooks_names: Union[str, List[str]] = "load_sample", reset: bool = True) -> None: - """This function is used to visualize a batch from the train dataloader.""" - stage_name: str = _STAGES_PREFIX[RunningStage.TRAINING] - self._show_batch(stage_name, hooks_names, reset=reset) - - def show_val_batch(self, hooks_names: Union[str, List[str]] = "load_sample", reset: bool = True) -> None: - """This function is used to visualize a batch from the validation dataloader.""" - stage_name: str = _STAGES_PREFIX[RunningStage.VALIDATING] - self._show_batch(stage_name, hooks_names, reset=reset) - - def show_test_batch(self, hooks_names: Union[str, List[str]] = "load_sample", reset: bool = True) -> None: - """This function is used to visualize a batch from the test dataloader.""" - stage_name: str = _STAGES_PREFIX[RunningStage.TESTING] - self._show_batch(stage_name, hooks_names, reset=reset) - - def show_predict_batch(self, hooks_names: Union[str, List[str]] = "load_sample", reset: bool = True) -> None: - """This function is used to visualize a batch from the predict dataloader.""" - stage_name: str = _STAGES_PREFIX[RunningStage.PREDICTING] - self._show_batch(stage_name, hooks_names, reset=reset) - - @staticmethod - def get_dataset_attribute(dataset: torch.utils.data.Dataset, attr_name: str, default: Optional[Any] = None) -> Any: - if isinstance(dataset, Subset): - return getattr(dataset.dataset, attr_name, default) - - return getattr(dataset, attr_name, default) - - @staticmethod - def set_dataset_attribute(dataset: torch.utils.data.Dataset, attr_name: str, value: Any) -> None: - if isinstance(dataset, Subset): - dataset = dataset.dataset - if isinstance(dataset, (Dataset, IterableDataset)): - setattr(dataset, attr_name, value) - - def set_running_stages(self): - if self._train_ds: - 
self.set_dataset_attribute(self._train_ds, "running_stage", RunningStage.TRAINING) + return ds.transform - if self._val_ds: - self.set_dataset_attribute(self._val_ds, "running_stage", RunningStage.VALIDATING) - - if self._test_ds: - self.set_dataset_attribute(self._test_ds, "running_stage", RunningStage.TESTING) - - if self._predict_ds: - self.set_dataset_attribute(self._predict_ds, "running_stage", RunningStage.PREDICTING) + def _resolve_dataloader_collate_fn(self, ds: Optional[Input]) -> Optional[Callable]: + if not ds: + return None + if isinstance(ds.transform, InputTransform): + return ds._create_dataloader_collate_fn([self.data_fetcher]) + return default_collate - def _resolve_collate_fn(self, dataset: Dataset, running_stage: RunningStage) -> Optional[Callable]: - if isinstance(dataset, (SplitDataset, InputBase)): - return self.data_pipeline.worker_input_transform_processor(running_stage) + def _resolve_on_after_batch_transfer_fn(self, ds: Optional[Input]) -> Optional[Callable]: + if not ds: + return None + if isinstance(ds.transform, InputTransform): + return ds._create_on_after_batch_transfer_fn([self.data_fetcher]) def _train_dataloader(self) -> DataLoader: - """Configure the train dataloader of the datamodule.""" - train_ds: Dataset = self._train_ds() if isinstance(self._train_ds, Callable) else self._train_ds + train_ds: Input = self._train_input + collate_fn = self._train_dataloader_collate_fn shuffle: bool = False - collate_fn = self._resolve_collate_fn(train_ds, RunningStage.TRAINING) - if isinstance(train_ds, IterableInput): + if isinstance(train_ds, IterableDataset): drop_last = False else: drop_last = len(train_ds) > self.batch_size - pin_memory = True - persistent_workers = self.num_workers > 0 if self.sampler is None: sampler = None - shuffle = not isinstance(train_ds, (IterableDataset, IterableInput)) + shuffle = not isinstance(train_ds, IterableDataset) else: sampler = self.sampler(train_ds) if isinstance(getattr(self, "trainer", None), pl.Trainer): + if isinstance(self.trainer.lightning_module, flash.Task): + self.connect(self.trainer.lightning_module) return self.trainer.lightning_module.process_train_dataset( train_ds, trainer=self.trainer, batch_size=self.batch_size, num_workers=self.num_workers, - pin_memory=pin_memory, + pin_memory=self.pin_memory, shuffle=shuffle, drop_last=drop_last, collate_fn=collate_fn, @@ -310,26 +228,25 @@ def _train_dataloader(self) -> DataLoader: shuffle=shuffle, sampler=sampler, num_workers=self.num_workers, - pin_memory=pin_memory, + pin_memory=self.pin_memory, drop_last=drop_last, collate_fn=collate_fn, - persistent_workers=persistent_workers, + persistent_workers=self.persistent_workers, ) def _val_dataloader(self) -> DataLoader: - """Configure the validation dataloader of the datamodule.""" - val_ds: Dataset = self._val_ds() if isinstance(self._val_ds, Callable) else self._val_ds - collate_fn = self._resolve_collate_fn(val_ds, RunningStage.VALIDATING) - pin_memory = True - persistent_workers = self.num_workers > 0 + val_ds: Input = self._val_input + collate_fn = self._val_dataloader_collate_fn if isinstance(getattr(self, "trainer", None), pl.Trainer): + if isinstance(self.trainer.lightning_module, flash.Task): + self.connect(self.trainer.lightning_module) return self.trainer.lightning_module.process_val_dataset( val_ds, trainer=self.trainer, batch_size=self.batch_size, num_workers=self.num_workers, - pin_memory=pin_memory, + pin_memory=self.pin_memory, collate_fn=collate_fn, ) @@ -337,25 +254,24 @@ def _val_dataloader(self) -> 
DataLoader: val_ds, batch_size=self.batch_size, num_workers=self.num_workers, - pin_memory=pin_memory, + pin_memory=self.pin_memory, collate_fn=collate_fn, - persistent_workers=persistent_workers, + persistent_workers=self.persistent_workers, ) def _test_dataloader(self) -> DataLoader: - """Configure the test dataloader of the datamodule.""" - test_ds: Dataset = self._test_ds() if isinstance(self._test_ds, Callable) else self._test_ds - collate_fn = self._resolve_collate_fn(test_ds, RunningStage.TESTING) - pin_memory = True - persistent_workers = False + test_ds: Input = self._test_input + collate_fn = self._test_dataloader_collate_fn if isinstance(getattr(self, "trainer", None), pl.Trainer): + if isinstance(self.trainer.lightning_module, flash.Task): + self.connect(self.trainer.lightning_module) return self.trainer.lightning_module.process_test_dataset( test_ds, trainer=self.trainer, batch_size=self.batch_size, num_workers=self.num_workers, - pin_memory=pin_memory, + pin_memory=self.pin_memory, collate_fn=collate_fn, ) @@ -363,30 +279,28 @@ def _test_dataloader(self) -> DataLoader: test_ds, batch_size=self.batch_size, num_workers=self.num_workers, - pin_memory=pin_memory, + pin_memory=self.pin_memory, collate_fn=collate_fn, - persistent_workers=persistent_workers, + persistent_workers=self.persistent_workers, ) def _predict_dataloader(self) -> DataLoader: - """Configure the prediction dataloader of the datamodule.""" - predict_ds: Dataset = self._predict_ds() if isinstance(self._predict_ds, Callable) else self._predict_ds + predict_ds: Input = self._predict_input + collate_fn = self._predict_dataloader_collate_fn - if isinstance(predict_ds, IterableInput): + if isinstance(predict_ds, IterableDataset): batch_size = self.batch_size else: batch_size = min(self.batch_size, len(predict_ds) if len(predict_ds) > 0 else 1) - collate_fn = self._resolve_collate_fn(predict_ds, RunningStage.PREDICTING) - pin_memory = True - persistent_workers = False - if isinstance(getattr(self, "trainer", None), pl.Trainer): + if isinstance(self.trainer.lightning_module, flash.Task): + self.connect(self.trainer.lightning_module) return self.trainer.lightning_module.process_predict_dataset( predict_ds, batch_size=batch_size, num_workers=self.num_workers, - pin_memory=pin_memory, + pin_memory=self.pin_memory, collate_fn=collate_fn, ) @@ -394,11 +308,132 @@ def _predict_dataloader(self) -> DataLoader: predict_ds, batch_size=batch_size, num_workers=self.num_workers, - pin_memory=True, + pin_memory=self.pin_memory, collate_fn=collate_fn, - persistent_workers=persistent_workers, + persistent_workers=self.persistent_workers, ) + def connect(self, task: "flash.Task"): + data_pipeline_state = DataPipelineState() + for properties in [ + self._train_input, + self._val_input, + self._test_input, + self._predict_input, + getattr(self._train_input, "transform", None), + getattr(self._val_input, "transform", None), + getattr(self._test_input, "transform", None), + getattr(self._predict_input, "transform", None), + task._deserializer, + task._output_transform, + task._output, + task, + ]: + if properties is not None and hasattr(properties, "attach_data_pipeline_state"): + properties.attach_data_pipeline_state(data_pipeline_state) + + def on_after_batch_transfer(self, batch: Any, dataloader_idx: int) -> Any: + if getattr(self, "trainer", None) is None: + return batch + transform = None + if self.trainer.training: + transform = self._train_on_after_batch_transfer_fn + elif self.trainer.validating or self.trainer.sanity_checking: + 
transform = self._val_on_after_batch_transfer_fn + elif self.trainer.testing: + transform = self._test_on_after_batch_transfer_fn + elif self.trainer.predicting: + transform = self._predict_on_after_batch_transfer_fn + + if transform: + batch = transform(batch) + + return batch + + @property + def viz(self) -> BaseVisualization: + return self._viz or DataModule.configure_data_fetcher() + + @viz.setter + def viz(self, viz: BaseVisualization) -> None: + self._viz = viz + + @staticmethod + def configure_data_fetcher(*args, **kwargs) -> BaseDataFetcher: + """This function is used to configure a :class:`~flash.core.data.callback.BaseDataFetcher`. + + Override with your custom one. + """ + return BaseDataFetcher() + + @property + def data_fetcher(self) -> BaseDataFetcher: + """This property returns the data fetcher.""" + return self._data_fetcher or DataModule.configure_data_fetcher() + + @data_fetcher.setter + def data_fetcher(self, data_fetcher: BaseDataFetcher) -> None: + self._data_fetcher = data_fetcher + + def _reset_iterator(self, stage: str) -> Iterable[Any]: + iter_name = f"_{stage}_iter" + # num_workers has to be set to 0 to work properly + num_workers = self.num_workers + self.num_workers = 0 + dataloader_fn = getattr(self, f"{stage}_dataloader") + iterator = iter(dataloader_fn()) + self.num_workers = num_workers + setattr(self, iter_name, iterator) + return iterator + + def _show_batch(self, stage: str, func_names: Union[str, List[str]], reset: bool = True) -> None: + """This function is used to handle transforms profiling for batch visualization.""" + # don't show in CI + if os.getenv("FLASH_TESTING", "0") == "1": + return None + iter_name = f"_{stage}_iter" + + if not hasattr(self, iter_name): + self._reset_iterator(stage) + + # list of functions to visualise + if isinstance(func_names, str): + func_names = [func_names] + + iter_dataloader = getattr(self, iter_name) + with self.data_fetcher.enable(): + if reset: + self.data_fetcher.batches[stage] = {} + try: + _ = next(iter_dataloader) + except StopIteration: + iter_dataloader = self._reset_iterator(stage) + _ = next(iter_dataloader) + data_fetcher: BaseVisualization = self.data_fetcher + data_fetcher._show(stage, func_names) + if reset: + self.data_fetcher.batches[stage] = {} + + def show_train_batch(self, hooks_names: Union[str, List[str]] = "load_sample", reset: bool = True) -> None: + """This function is used to visualize a batch from the train dataloader.""" + stage_name: str = _STAGES_PREFIX[RunningStage.TRAINING] + self._show_batch(stage_name, hooks_names, reset=reset) + + def show_val_batch(self, hooks_names: Union[str, List[str]] = "load_sample", reset: bool = True) -> None: + """This function is used to visualize a batch from the validation dataloader.""" + stage_name: str = _STAGES_PREFIX[RunningStage.VALIDATING] + self._show_batch(stage_name, hooks_names, reset=reset) + + def show_test_batch(self, hooks_names: Union[str, List[str]] = "load_sample", reset: bool = True) -> None: + """This function is used to visualize a batch from the test dataloader.""" + stage_name: str = _STAGES_PREFIX[RunningStage.TESTING] + self._show_batch(stage_name, hooks_names, reset=reset) + + def show_predict_batch(self, hooks_names: Union[str, List[str]] = "load_sample", reset: bool = True) -> None: + """This function is used to visualize a batch from the predict dataloader.""" + stage_name: str = _STAGES_PREFIX[RunningStage.PREDICTING] + self._show_batch(stage_name, hooks_names, reset=reset) + @property def num_classes(self) -> Optional[int]: 
"""Property that returns the number of classes of the datamodule if a multiclass task.""" @@ -418,21 +453,14 @@ def multi_label(self) -> Optional[bool]: @property def inputs(self) -> Optional[Union[Input, List[InputBase]]]: """Property that returns the inputs associated with this ``DataModule``.""" - datasets = [self.train_dataset, self.val_dataset, self.test_dataset, self.predict_dataset] - inputs = [ - dataset - for dataset in datasets - if isinstance(dataset, InputBase) - or (isinstance(dataset, SplitDataset) and isinstance(dataset.dataset, InputBase)) - ] - if len(inputs) == 0: - inputs = self._input - return inputs + inputs = [self.train_dataset, self.val_dataset, self.test_dataset, self.predict_dataset] + return [input for input in inputs if input] @property def input_transform(self) -> InputTransform: """Property that returns the input transform class used on input data.""" - return self._input_transform or self.input_transform_cls() + # Find a better way to resolve this. + return getattr(self.train_dataset, "transform", None) or self.input_transform_cls(RunningStage.TRAINING) @property def output_transform(self) -> OutputTransform: diff --git a/flash/core/data/data_pipeline.py b/flash/core/data/data_pipeline.py index 4d1ae4af36..08bc5b509e 100644 --- a/flash/core/data/data_pipeline.py +++ b/flash/core/data/data_pipeline.py @@ -11,49 +11,23 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import functools import inspect -import weakref -from functools import partial -from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, Type, TYPE_CHECKING, Union +from typing import Any, Dict, List, Optional, Sequence, Set, Type, Union -import torch from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.model_helpers import is_overridden -from torch.utils.data import DataLoader, IterableDataset -import flash -from flash.core.data.batch import _DeserializeProcessor, _DeserializeProcessorV2 +from flash.core.data.batch import _DeserializeProcessorV2 from flash.core.data.input_transform import _create_collate_input_transform_processors +from flash.core.data.input_transform import InputTransform from flash.core.data.input_transform import InputTransform as NewInputTransform -from flash.core.data.io.input import Input, InputBase, IterableInput -from flash.core.data.io.input_transform import _InputTransformProcessor, DefaultInputTransform, InputTransform +from flash.core.data.io.input import Input, InputBase +from flash.core.data.io.input_transform import _InputTransformProcessorV2 from flash.core.data.io.output import _OutputProcessor, Output from flash.core.data.io.output_transform import _OutputTransformProcessor, OutputTransform from flash.core.data.process import Deserializer from flash.core.data.properties import ProcessState -from flash.core.data.utils import _INPUT_TRANSFORM_FUNCS, _OUTPUT_TRANSFORM_FUNCS, _STAGES_PREFIX -from flash.core.utilities.imports import _PL_GREATER_EQUAL_1_4_3, _PL_GREATER_EQUAL_1_5_0 -from flash.core.utilities.stages import _RUNNING_STAGE_MAPPING, RunningStage - -if not _PL_GREATER_EQUAL_1_5_0: - from pytorch_lightning.trainer.connectors.data_connector import _PatchDataLoader - -if TYPE_CHECKING: - from flash.core.model import Task - - -class DataLoaderGetter: - """A utility class to be used when patching the ``{stage}_dataloader`` attribute of a 
LightningModule.""" - - def __init__(self, dataloader): - self.dataloader = dataloader - - # Dummy `__code__` attribute to trick is_overridden - self.__code__ = self.__call__.__code__ - - def __call__(self): - return self.dataloader +from flash.core.data.utils import _INPUT_TRANSFORM_FUNCS, _OUTPUT_TRANSFORM_FUNCS +from flash.core.utilities.stages import RunningStage class DataPipelineState: @@ -97,7 +71,7 @@ def __init__( ) -> None: self.input = input - self._input_transform_pipeline = input_transform or DefaultInputTransform() + self._input_transform_pipeline = input_transform or InputTransform(RunningStage.TRAINING) self._output_transform = output_transform or OutputTransform() self._output = output or Output() self._deserializer = deserializer or Deserializer() @@ -110,7 +84,9 @@ def initialize(self, data_pipeline_state: Optional[DataPipelineState] = None) -> data_pipeline_state = data_pipeline_state or DataPipelineState() if self.input is not None: if isinstance(self.input, list): - [input.attach_data_pipeline_state(data_pipeline_state) for input in self.input] + for input in self.input: + if hasattr(input, "attach_data_pipeline_state"): + input.attach_data_pipeline_state(data_pipeline_state) else: self.input.attach_data_pipeline_state(data_pipeline_state) self._deserializer.attach_data_pipeline_state(data_pipeline_state) @@ -119,10 +95,6 @@ def initialize(self, data_pipeline_state: Optional[DataPipelineState] = None) -> self._output.attach_data_pipeline_state(data_pipeline_state) return data_pipeline_state - @property - def example_input(self) -> str: - return self._deserializer.example_input - @staticmethod def _is_overridden(method_name: str, process_obj, super_obj: Any, prefix: Optional[str] = None) -> bool: """Cropped Version of https://github.com/PyTorchLightning/pytorch- @@ -155,7 +127,7 @@ def _is_overridden_recursive( return DataPipeline._is_overridden_recursive(method_name, process_obj, super_obj) current_code = inspect.unwrap(getattr(process_obj, current_method_name)).__code__ - has_different_code = current_code != getattr(super_obj, method_name).__code__ + has_different_code = current_code != getattr(super_obj, current_method_name).__code__ if not prefix: return has_different_code @@ -165,29 +137,16 @@ def _is_overridden_recursive( def _identity(samples: Sequence[Any]) -> Sequence[Any]: return samples - def deserialize_processor(self) -> _DeserializeProcessor: - if isinstance(self._input_transform_pipeline, NewInputTransform): - return _DeserializeProcessorV2( - self._deserializer, - self._input_transform_pipeline, - self._input_transform_pipeline._per_sample_transform, - [], - ) - return self._create_collate_input_transform_processors(RunningStage.PREDICTING)[0] - - def worker_input_transform_processor( - self, running_stage: RunningStage, collate_fn: Optional[Callable] = None, is_serving: bool = False - ) -> _InputTransformProcessor: - if isinstance(self._input_transform_pipeline, NewInputTransform): - return _create_collate_input_transform_processors(self._input_transform_pipeline, [])[0] - return self._create_collate_input_transform_processors( - running_stage, collate_fn=collate_fn, is_serving=is_serving - )[1] - - def device_input_transform_processor(self, running_stage: RunningStage) -> _InputTransformProcessor: - if isinstance(self._input_transform_pipeline, NewInputTransform): - return _create_collate_input_transform_processors(self._input_transform_pipeline, [])[1] - return self._create_collate_input_transform_processors(running_stage)[2] + def 
deserialize_processor(self) -> _DeserializeProcessorV2: + return _DeserializeProcessorV2( + self._deserializer, + self._input_transform_pipeline, + self._input_transform_pipeline._per_sample_transform, + [], + ) + + def device_input_transform_processor(self, running_stage: RunningStage) -> _InputTransformProcessorV2: + return _create_collate_input_transform_processors(self._input_transform_pipeline, [])[1] def output_transform_processor(self, running_stage: RunningStage, is_serving=False) -> _OutputTransformProcessor: return self._create_output_transform_processor(running_stage, is_serving=is_serving) @@ -223,241 +182,6 @@ def _resolve_function_hierarchy( return function_name - def _make_collates(self, on_device: bool, collate: Callable) -> Tuple[Callable, Callable]: - if on_device: - return self._identity, collate - return collate, self._identity - - def _create_collate_input_transform_processors( - self, - stage: RunningStage, - collate_fn: Optional[Callable] = None, - is_serving: bool = False, - ) -> Tuple[_DeserializeProcessor, _InputTransformProcessor, _InputTransformProcessor]: - - original_collate_fn = collate_fn - - input_transform: InputTransform = self._input_transform_pipeline - prefix: str = _STAGES_PREFIX[stage] - - if collate_fn is not None: - input_transform._default_collate = collate_fn - - func_names: Dict[str, str] = { - k: self._resolve_function_hierarchy(k, input_transform, stage, InputTransform) - for k in self.INPUT_TRANSFORM_FUNCS - } - - collate_fn: Callable = getattr(input_transform, func_names["collate"]) - - per_batch_transform_overridden: bool = self._is_overridden_recursive( - "per_batch_transform", input_transform, InputTransform, prefix=prefix - ) - - per_sample_transform_on_device_overridden: bool = self._is_overridden_recursive( - "per_sample_transform_on_device", input_transform, InputTransform, prefix=prefix - ) - - collate_in_worker_from_transform: Optional[bool] = getattr( - input_transform, f"_{prefix}_collate_in_worker_from_transform", None - ) - - is_per_overridden = per_batch_transform_overridden and per_sample_transform_on_device_overridden - if collate_in_worker_from_transform is None and is_per_overridden: - raise MisconfigurationException( - f"{self.__class__.__name__}: `per_batch_transform` and `per_sample_transform_on_device` " - f"are mutually exclusive for stage {stage}" - ) - - if isinstance(collate_in_worker_from_transform, bool): - worker_collate_fn, device_collate_fn = self._make_collates(not collate_in_worker_from_transform, collate_fn) - else: - worker_collate_fn, device_collate_fn = self._make_collates( - per_sample_transform_on_device_overridden, collate_fn - ) - - worker_collate_fn = ( - worker_collate_fn.collate_fn - if isinstance(worker_collate_fn, _InputTransformProcessor) - else worker_collate_fn - ) - - per_sample_transform = getattr(input_transform, func_names["per_sample_transform"]) - - deserialize_processor = _DeserializeProcessor( - self._deserializer, - input_transform, - per_sample_transform, - callbacks=input_transform.callbacks, - ) - worker_input_transform_processor = _InputTransformProcessor( - input_transform, - worker_collate_fn, - self._identity if is_serving else per_sample_transform, - getattr(input_transform, func_names["per_batch_transform"]), - stage, - callbacks=input_transform.callbacks, - ) - worker_input_transform_processor._original_collate_fn = original_collate_fn - device_input_transform_processor = _InputTransformProcessor( - input_transform, - device_collate_fn, - getattr(input_transform, 
func_names["per_sample_transform_on_device"]), - getattr(input_transform, func_names["per_batch_transform_on_device"]), - stage, - apply_per_sample_transform=device_collate_fn != self._identity, - on_device=True, - callbacks=input_transform.callbacks, - ) - return deserialize_processor, worker_input_transform_processor, device_input_transform_processor - - @staticmethod - def _model_transfer_to_device_wrapper( - func: Callable, input_transform: _InputTransformProcessor, model: "Task", stage: RunningStage - ) -> Callable: - - if not isinstance(func, _StageOrchestrator): - func = _StageOrchestrator(func, model) - func.register_additional_stage(stage, input_transform) - - return func - - @staticmethod - def _model_predict_step_wrapper( - func: Callable, output_transform_processor: _OutputTransformProcessor, model: "Task" - ) -> Callable: - - if not isinstance(func, _StageOrchestrator): - _original = func - func = _StageOrchestrator(func, model) - func._original = _original - func.register_additional_stage(RunningStage.PREDICTING, output_transform_processor) - - return func - - @staticmethod - def _get_dataloader(model: "Task", loader_name: str) -> Tuple[DataLoader, str]: - dataloader, attr_name = None, None - if is_overridden(loader_name, model): - dataloader = getattr(model, loader_name) - attr_name = loader_name - - elif ( - model.trainer - and hasattr(model.trainer, "datamodule") - and model.trainer.datamodule - and is_overridden(loader_name, model.trainer.datamodule, flash.DataModule) - ): - dataloader = getattr(model.trainer.datamodule, loader_name, None) - attr_name = f"trainer.datamodule.{loader_name}" - - elif _PL_GREATER_EQUAL_1_5_0 and model.trainer is not None: - source = getattr(model.trainer._data_connector, f"_{loader_name}_source") - if not source.is_module(): - dataloader = source.dataloader() - attr_name = loader_name - - if dataloader is not None: - # Update source as wrapped loader will be attached to model - source.instance = model - source.name = loader_name - - return dataloader, attr_name - - @staticmethod - def _patch_dataloader(model: "Task", dataloader: Union[Callable, DataLoader], stage: RunningStage): - if isinstance(dataloader, DataLoader): - if _PL_GREATER_EQUAL_1_5_0: - dataloader = DataLoaderGetter(dataloader) - elif _PL_GREATER_EQUAL_1_4_3: - dataloader = _PatchDataLoader(dataloader, _STAGES_PREFIX[stage]) - dataloader.patch(model) - else: - dataloader = _PatchDataLoader(dataloader) - return dataloader - - @staticmethod - def _set_loader(model: "Task", loader_name: str, new_loader: DataLoader) -> None: - """This function is used to set the loader to model and/or datamodule.""" - *intermediates, final_name = loader_name.split(".") - curr_attr = model - - # This relies on python calling all non-integral types by reference. - # It may fail for integral types since those will be called by value. 
- for intermediate in intermediates: - curr_attr = getattr(curr_attr, intermediate) - - setattr(curr_attr, final_name, new_loader) - setattr(model, final_name, new_loader) - - def _attach_input_transform_to_model( - self, - model: "Task", - stage: Optional[RunningStage] = None, - device_transform_only: bool = False, - is_serving: bool = False, - ) -> None: - device_collate_fn = torch.nn.Identity() - - if not stage: - stages = [RunningStage.TRAINING, RunningStage.VALIDATING, RunningStage.TESTING, RunningStage.PREDICTING] - - elif isinstance(stage, RunningStage): - stages = [stage] - - for stage in stages: - - loader_name = f"{_STAGES_PREFIX[stage]}_dataloader" - - dataloader, whole_attr_name = self._get_dataloader(model, loader_name) - - if not dataloader: - continue - - if callable(dataloader): - dataloader = dataloader() - - if dataloader is None: - continue - - if isinstance(dataloader, Sequence): - was_seq = True - else: - dataloader = [dataloader] - was_seq = False - - for idx, loader in enumerate(dataloader): - # TODO: See lightning for proper reinstantiation of loader - if isinstance(loader, DataLoader): - dl_args = {k: v for k, v in vars(loader).items() if not k.startswith("_")} - - _, dl_args["collate_fn"], device_collate_fn = self._create_collate_input_transform_processors( - stage=stage, collate_fn=dl_args["collate_fn"], is_serving=is_serving - ) - - if isinstance(dl_args["dataset"], IterableDataset): - del dl_args["sampler"] - - # don't have to reinstantiate loader if just rewrapping devices (happens during detach) - if not device_transform_only: - del dl_args["batch_sampler"] - loader = type(loader)(**dl_args) - - dataloader[idx] = loader - - # don't have to set attribute if rewrapping device part (happens during detach) - if not device_transform_only: - if not was_seq: - dataloader = dataloader[0] - - dataloader = self._patch_dataloader(model, dataloader, stage) - - self._set_loader(model, whole_attr_name, dataloader) - - model.transfer_batch_to_device = self._model_transfer_to_device_wrapper( - model.transfer_batch_to_device, device_collate_fn, model, stage - ) - def _create_output_transform_processor( self, stage: RunningStage, @@ -478,107 +202,6 @@ def _create_output_transform_processor( is_serving=is_serving, ) - def _attach_output_transform_to_model( - self, - model: "Task", - stage: RunningStage, - is_serving: bool = False, - ) -> "Task": - model.predict_step = self._model_predict_step_wrapper( - model.predict_step, self._create_output_transform_processor(stage, is_serving=is_serving), model - ) - return model - - def _attach_to_model( - self, - model: "Task", - stage: RunningStage = None, - is_serving: bool = False, - ): - # not necessary to detach. preprocessing and postprocessing for stage will be overwritten. 
- self._attach_input_transform_to_model(model, stage) - - if not stage or stage == RunningStage.PREDICTING: - self._attach_output_transform_to_model(model, RunningStage.PREDICTING, is_serving=is_serving) - - def _detach_from_model(self, model: "Task", stage: Optional[RunningStage] = None): - self._detach_input_transform_from_model(model, stage) - - if not stage or stage == RunningStage.PREDICTING: - self._detach_output_transform_from_model(model) - - def _detach_input_transform_from_model(self, model: "Task", stage: Optional[RunningStage] = None): - if not stage: - stages = [RunningStage.TRAINING, RunningStage.VALIDATING, RunningStage.TESTING, RunningStage.PREDICTING] - elif isinstance(stage, RunningStage): - stages = [stage] - - for stage in stages: - - device_collate = None - if isinstance(model.transfer_batch_to_device, _StageOrchestrator): - device_collate = model.transfer_batch_to_device.unregister_stage(stage) - - # if no additional funmc available: remove wrapper - if model.transfer_batch_to_device.is_empty(): - model.transfer_batch_to_device = model.transfer_batch_to_device.func - - if not device_collate: - device_collate = self._identity - - loader_name = f"{_STAGES_PREFIX[stage]}_dataloader" - - dataloader, whole_attr_name = self._get_dataloader(model, loader_name) - - if not dataloader: - continue - - if callable(dataloader): - dataloader = dataloader() - - if isinstance(dataloader, Sequence): - was_seq = True - else: - dataloader = [dataloader] - was_seq = False - - for idx, loader in enumerate(dataloader): - if isinstance(loader, DataLoader): - dl_args = {k: v for k, v in vars(loader).items() if not k.startswith("_")} - - # TODO: Remove the partial function once resolved on Lightning side. - if isinstance(dl_args["collate_fn"], partial): - default_collate = dl_args["collate_fn"].keywords.get("default_collate", None) - if default_collate: - dl_args["collate_fn"] = default_collate - - if isinstance(dl_args["collate_fn"], _InputTransformProcessor): - dl_args["collate_fn"] = dl_args["collate_fn"]._original_collate_fn - - if isinstance(dl_args["dataset"], (IterableInput, IterableDataset)): - del dl_args["sampler"] - - del dl_args["batch_sampler"] - - loader = type(loader)(**dl_args) - - dataloader[idx] = loader - - if not was_seq: - dataloader = dataloader[0] - - dataloader = self._patch_dataloader(model, dataloader, stage) - - self._set_loader(model, whole_attr_name, dataloader) - - @staticmethod - def _detach_output_transform_from_model(model: "Task"): - - if hasattr(model.predict_step, "_original"): - # don't delete the predict_step here since we don't know - # if any other pipeline is attached which may rely on this! 
- model.predict_step = model.predict_step._original - def __str__(self) -> str: input: Input = self.input input_transform: InputTransform = self._input_transform_pipeline @@ -593,44 +216,3 @@ def __str__(self) -> str: f"output_transform={output_transform}, " f"output={output})" ) - - -class _StageOrchestrator: - def __init__(self, func_to_wrap: Callable, model: "Task") -> None: - self.func = func_to_wrap - - self._stage_mapping = {k: None for k in RunningStage} - self.model = weakref.proxy(model) - - functools.update_wrapper(self, self.func) - - def __call__(self, *args, **kwargs): - outputs = self.func(*args, **kwargs) - - try: - stage = self.model.trainer._running_stage - except AttributeError: - stage = self.model.trainer.state.stage - - internal_running_state = _RUNNING_STAGE_MAPPING[stage] - additional_func = self._stage_mapping.get(internal_running_state, None) - - if additional_func: - outputs = additional_func(outputs) - - return outputs - - def register_additional_stage(self, stage: RunningStage, stage_func: Optional[Callable] = None): - assert stage_func is None or callable(stage_func) - - self._stage_mapping[stage] = stage_func.to(self.model.device, self.model.dtype) - - def unregister_stage(self, stage: RunningStage): - ret_val = self._stage_mapping.pop(stage) - self._stage_mapping[stage] = None - if ret_val: - ret_val = ret_val.cpu() - return ret_val - - def is_empty(self): - return all(v is None for v in self._stage_mapping.values()) or not self._stage_mapping diff --git a/flash/core/data/input_transform.py b/flash/core/data/input_transform.py index c6356ad089..6c1bf2e962 100644 --- a/flash/core/data/input_transform.py +++ b/flash/core/data/input_transform.py @@ -134,22 +134,22 @@ def transforms(self) -> Dict[str, Optional[Dict[str, Callable]]]: def per_sample_transform(self) -> Callable: """Defines the transform to be applied on a single sample on cpu for all stages stage. - The input data of the transform would have the following form: + The input data of the transform would have the following form:: - { - DataKeys.INPUT: ..., - DataKeys.TARGET: ..., - DataKeys.METADATA: ..., - } + { + DataKeys.INPUT: ..., + DataKeys.TARGET: ..., + DataKeys.METADATA: ..., + } You would need to use :class:`flash.core.data.transforms.ApplyToKeys` as follows: - Example: + .. code-block:: python from flash.core.data.transforms import ApplyToKeys - class MyInputTransform(InputTransform): + class MyInputTransform(InputTransform): def per_sample_transform(self) -> Callable: return ApplyToKeys("input", my_func) @@ -169,12 +169,13 @@ def target_per_sample_transform(self) -> Callable: def train_per_sample_transform(self) -> Callable: """Defines the transform to be applied on a single sample on cpu for the training stage. - The input data of the transform would have the following form: - { - DataKeys.INPUT: ..., - DataKeys.TARGET: ..., - DataKeys.METADATA: ..., - } + The input data of the transform would have the following form:: + + { + DataKeys.INPUT: ..., + DataKeys.TARGET: ..., + DataKeys.METADATA: ..., + } """ return self._identity @@ -191,21 +192,22 @@ def train_target_per_sample_transform(self) -> Callable: def val_per_sample_transform(self) -> Callable: """Defines the transform to be applied on a single sample on cpu for the validating stage. 
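One subtle fix in the hunk above: ``_is_overridden_recursive`` now compares the resolved method against the base-class attribute of the same (possibly stage-prefixed) name, so a hook such as ``train_per_batch_transform`` is compared like-for-like instead of against the generic ``per_batch_transform``. A simplified sketch of the resolution order this supports (illustrative only, not Flash's exact implementation)::

    def resolve_hook(transform, name: str, stage_prefix: str):
        # Prefer a stage-specific hook (e.g. ``train_per_batch_transform``) when the
        # subclass actually defines it; otherwise fall back to the generic hook.
        prefixed = f"{stage_prefix}_{name}"
        if prefixed in vars(type(transform)):
            return getattr(transform, prefixed)
        return getattr(transform, name)

    # resolve_hook(my_transform, "per_batch_transform", "train") returns
    # my_transform.train_per_batch_transform if overridden, else my_transform.per_batch_transform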
diff --git a/flash/core/data/input_transform.py b/flash/core/data/input_transform.py
index c6356ad089..6c1bf2e962 100644
--- a/flash/core/data/input_transform.py
+++ b/flash/core/data/input_transform.py
@@ -134,22 +134,22 @@ def transforms(self) -> Dict[str, Optional[Dict[str, Callable]]]:
 
     def per_sample_transform(self) -> Callable:
         """Defines the transform to be applied on a single sample on cpu for all stages stage.
 
-        The input data of the transform would have the following form:
+        The input data of the transform would have the following form::
 
-        {
-            DataKeys.INPUT: ...,
-            DataKeys.TARGET: ...,
-            DataKeys.METADATA: ...,
-        }
+            {
+                DataKeys.INPUT: ...,
+                DataKeys.TARGET: ...,
+                DataKeys.METADATA: ...,
+            }
 
         You would need to use :class:`flash.core.data.transforms.ApplyToKeys` as follows:
 
-        Example:
+        .. code-block:: python
 
             from flash.core.data.transforms import ApplyToKeys
 
-        class MyInputTransform(InputTransform):
+            class MyInputTransform(InputTransform):
 
                 def per_sample_transform(self) -> Callable:
                     return ApplyToKeys("input", my_func)
@@ -169,12 +169,13 @@ def target_per_sample_transform(self) -> Callable:
 
     def train_per_sample_transform(self) -> Callable:
         """Defines the transform to be applied on a single sample on cpu for the training stage.
 
-        The input data of the transform would have the following form:
-        {
-            DataKeys.INPUT: ...,
-            DataKeys.TARGET: ...,
-            DataKeys.METADATA: ...,
-        }
+        The input data of the transform would have the following form::
+
+            {
+                DataKeys.INPUT: ...,
+                DataKeys.TARGET: ...,
+                DataKeys.METADATA: ...,
+            }
         """
         return self._identity
@@ -191,21 +192,22 @@ def train_target_per_sample_transform(self) -> Callable:
 
     def val_per_sample_transform(self) -> Callable:
         """Defines the transform to be applied on a single sample on cpu for the validating stage.
 
-        The input data of the transform would have the following form:
-        {
-            DataKeys.INPUT: ...,
-            DataKeys.TARGET: ...,
-            DataKeys.METADATA: ...,
-        }
+        The input data of the transform would have the following form::
+
+            {
+                DataKeys.INPUT: ...,
+                DataKeys.TARGET: ...,
+                DataKeys.METADATA: ...,
+            }
 
         You would need to use :class:`flash.core.data.transforms.ApplyToKeys` as follows:
 
-        Example:
+        .. code-block:: python
 
             from flash.core.data.transforms import ApplyToKeys
 
-        class MyInputTransform(InputTransform):
+            class MyInputTransform(InputTransform):
 
                 def per_sample_transform(self) -> Callable:
                     return ApplyToKeys("input", my_func)
@@ -225,12 +227,13 @@ def val_target_per_sample_transform(self) -> Callable:
 
     def test_per_sample_transform(self) -> Callable:
         """Defines the transform to be applied on a single sample on cpu for the testing stage.
 
-        The input data of the transform would have the following form:
-        {
-            DataKeys.INPUT: ...,
-            DataKeys.TARGET: ...,
-            DataKeys.METADATA: ...,
-        }
+        The input data of the transform would have the following form::
+
+            {
+                DataKeys.INPUT: ...,
+                DataKeys.TARGET: ...,
+                DataKeys.METADATA: ...,
+            }
         """
         return self._identity
@@ -247,21 +250,22 @@ def test_target_per_sample_transform(self) -> Callable:
 
     def predict_per_sample_transform(self) -> Callable:
         """Defines the transform to be applied on a single sample on cpu for the predicting stage.
 
-        The input data of the transform would have the following form:
-        {
-            DataKeys.INPUT: ...,
-            DataKeys.TARGET: ...,
-            DataKeys.METADATA: ...,
-        }
+        The input data of the transform would have the following form::
+
+            {
+                DataKeys.INPUT: ...,
+                DataKeys.TARGET: ...,
+                DataKeys.METADATA: ...,
+            }
 
         You would need to use :class:`flash.core.data.transforms.ApplyToKeys` as follows:
 
-        Example:
+        .. code-block:: python
 
             from flash.core.data.transforms import ApplyToKeys
 
-        class MyInputTransform(InputTransform):
+            class MyInputTransform(InputTransform):
 
                 def per_sample_transform(self) -> Callable:
                     return ApplyToKeys("input", my_func)
@@ -281,21 +285,22 @@ def predict_target_per_sample_transform(self) -> Callable:
 
     def serve_per_sample_transform(self) -> Callable:
         """Defines the transform to be applied on a single sample on cpu for the serving stage.
 
-        The input data of the transform would have the following form:
-        {
-            DataKeys.INPUT: ...,
-            DataKeys.TARGET: ...,
-            DataKeys.METADATA: ...,
-        }
+        The input data of the transform would have the following form::
+
+            {
+                DataKeys.INPUT: ...,
+                DataKeys.TARGET: ...,
+                DataKeys.METADATA: ...,
+            }
 
         You would need to use :class:`flash.core.data.transforms.ApplyToKeys` as follows:
 
-        Example:
+        .. code-block:: python
 
             from flash.core.data.transforms import ApplyToKeys
 
-        class MyInputTransform(InputTransform):
+            class MyInputTransform(InputTransform):
 
                 def per_sample_transform(self) -> Callable:
                     return ApplyToKeys("input", my_func)
@@ -319,22 +324,22 @@ def serve_target_per_sample_transform(self) -> Callable:
 
     def per_sample_transform_on_device(self) -> Callable:
         """Defines the transform to be applied on a single sample on device for all stages stage.
 
-        The input data of the transform would have the following form:
+        The input data of the transform would have the following form::
 
-        {
-            DataKeys.INPUT: ...,
-            DataKeys.TARGET: ...,
-            DataKeys.METADATA: ...,
-        }
+            {
+                DataKeys.INPUT: ...,
+                DataKeys.TARGET: ...,
+                DataKeys.METADATA: ...,
+            }
 
         You would need to use :class:`flash.core.data.transforms.ApplyToKeys` as follows:
 
-        Example:
+        .. code-block:: python
 
             from flash.core.data.transforms import ApplyToKeys
 
-        class MyInputTransform(InputTransform):
+            class MyInputTransform(InputTransform):
 
                 def per_sample_transform_on_device(self) -> Callable:
                     return ApplyToKeys("input", my_func)
@@ -354,12 +359,13 @@ def target_per_sample_transform_on_device(self) -> Callable:
 
     def train_per_sample_transform_on_device(self) -> Callable:
         """Defines the transform to be applied on a single sample on device for the training stage.
 
-        The input data of the transform would have the following form:
-        {
-            DataKeys.INPUT: ...,
-            DataKeys.TARGET: ...,
-            DataKeys.METADATA: ...,
-        }
+        The input data of the transform would have the following form::
+
+            {
+                DataKeys.INPUT: ...,
+                DataKeys.TARGET: ...,
+                DataKeys.METADATA: ...,
+            }
         """
         return self._identity
@@ -376,21 +382,22 @@ def train_target_per_sample_transform_on_device(self) -> Callable:
 
     def val_per_sample_transform_on_device(self) -> Callable:
         """Defines the transform to be applied on a single sample on device for the validating stage.
 
-        The input data of the transform would have the following form:
-        {
-            DataKeys.INPUT: ...,
-            DataKeys.TARGET: ...,
-            DataKeys.METADATA: ...,
-        }
+        The input data of the transform would have the following form::
+
+            {
+                DataKeys.INPUT: ...,
+                DataKeys.TARGET: ...,
+                DataKeys.METADATA: ...,
+            }
 
         You would need to use :class:`flash.core.data.transforms.ApplyToKeys` as follows:
 
-        Example:
+        .. code-block:: python
 
             from flash.core.data.transforms import ApplyToKeys
 
-        class MyInputTransform(InputTransform):
+            class MyInputTransform(InputTransform):
 
                 def per_sample_transform_on_device(self) -> Callable:
                     return ApplyToKeys("input", my_func)
@@ -410,12 +417,13 @@ def val_target_per_sample_transform_on_device(self) -> Callable:
 
     def test_per_sample_transform_on_device(self) -> Callable:
         """Defines the transform to be applied on a single sample on device for the testing stage.
 
-        The input data of the transform would have the following form:
-        {
-            DataKeys.INPUT: ...,
-            DataKeys.TARGET: ...,
-            DataKeys.METADATA: ...,
-        }
+        The input data of the transform would have the following form::
+
+            {
+                DataKeys.INPUT: ...,
+                DataKeys.TARGET: ...,
+                DataKeys.METADATA: ...,
+            }
         """
         return self._identity
@@ -432,21 +440,22 @@ def test_target_per_sample_transform_on_device(self) -> Callable:
 
     def predict_per_sample_transform_on_device(self) -> Callable:
         """Defines the transform to be applied on a single sample on device for the predicting stage.
 
-        The input data of the transform would have the following form:
-        {
-            DataKeys.INPUT: ...,
-            DataKeys.TARGET: ...,
-            DataKeys.METADATA: ...,
-        }
+        The input data of the transform would have the following form::
+
+            {
+                DataKeys.INPUT: ...,
+                DataKeys.TARGET: ...,
+                DataKeys.METADATA: ...,
+            }
 
         You would need to use :class:`flash.core.data.transforms.ApplyToKeys` as follows:
 
-        Example:
+        .. code-block:: python
 
             from flash.core.data.transforms import ApplyToKeys
 
-        class MyInputTransform(InputTransform):
+            class MyInputTransform(InputTransform):
 
                 def per_sample_transform_on_device(self) -> Callable:
                     return ApplyToKeys("input", my_func)
@@ -470,21 +479,22 @@ def predict_target_per_sample_transform_on_device(self) -> Callable:
 
     def per_batch_transform(self) -> Callable:
         """Defines the transform to be applied on a batch of data on cpu for all stages stage.
 
-        The input data of the transform would have the following form:
-        {
-            DataKeys.INPUT: ...,
-            DataKeys.TARGET: ...,
-            DataKeys.METADATA: ...,
-        }
+        The input data of the transform would have the following form::
+
+            {
+                DataKeys.INPUT: ...,
+                DataKeys.TARGET: ...,
+                DataKeys.METADATA: ...,
+            }
 
         You would need to use :class:`flash.core.data.transforms.ApplyToKeys` as follows:
 
-        Example:
+        .. code-block:: python
 
             from flash.core.data.transforms import ApplyToKeys
 
-        class MyInputTransform(InputTransform):
+            class MyInputTransform(InputTransform):
 
                 def per_batch_transform(self) -> Callable:
                     return ApplyToKeys("input", my_func)
@@ -504,12 +514,13 @@ def target_per_batch_transform(self) -> Callable:
 
     def train_per_batch_transform(self) -> Callable:
         """Defines the transform to be applied on a batch of data on cpu for the training stage.
 
-        The input data of the transform would have the following form:
-        {
-            DataKeys.INPUT: ...,
-            DataKeys.TARGET: ...,
-            DataKeys.METADATA: ...,
-        }
+        The input data of the transform would have the following form::
+
+            {
+                DataKeys.INPUT: ...,
+                DataKeys.TARGET: ...,
+                DataKeys.METADATA: ...,
+            }
         """
         return self._identity
@@ -526,21 +537,22 @@ def train_target_per_batch_transform(self) -> Callable:
 
     def val_per_batch_transform(self) -> Callable:
         """Defines the transform to be applied on a batch of data on cpu for the validating stage.
 
-        The input data of the transform would have the following form:
-        {
-            DataKeys.INPUT: ...,
-            DataKeys.TARGET: ...,
-            DataKeys.METADATA: ...,
-        }
+        The input data of the transform would have the following form::
+
+            {
+                DataKeys.INPUT: ...,
+                DataKeys.TARGET: ...,
+                DataKeys.METADATA: ...,
+            }
 
         You would need to use :class:`flash.core.data.transforms.ApplyToKeys` as follows:
 
-        Example:
+        .. code-block:: python
 
             from flash.core.data.transforms import ApplyToKeys
 
-        class MyInputTransform(InputTransform):
+            class MyInputTransform(InputTransform):
 
                 def per_batch_transform(self) -> Callable:
                     return ApplyToKeys("input", my_func)
@@ -560,12 +572,13 @@ def val_target_per_batch_transform(self) -> Callable:
 
     def test_per_batch_transform(self) -> Callable:
         """Defines the transform to be applied on a batch of data on cpu for the testing stage.
 
-        The input data of the transform would have the following form:
-        {
-            DataKeys.INPUT: ...,
-            DataKeys.TARGET: ...,
-            DataKeys.METADATA: ...,
-        }
+        The input data of the transform would have the following form::
+
+            {
+                DataKeys.INPUT: ...,
+                DataKeys.TARGET: ...,
+                DataKeys.METADATA: ...,
+            }
         """
         return self._identity
@@ -582,21 +595,22 @@ def test_target_per_batch_transform(self) -> Callable:
 
     def predict_per_batch_transform(self) -> Callable:
         """Defines the transform to be applied on a batch of data on cpu for the predicting stage.
 
-        The input data of the transform would have the following form:
-        {
-            DataKeys.INPUT: ...,
-            DataKeys.TARGET: ...,
-            DataKeys.METADATA: ...,
-        }
+        The input data of the transform would have the following form::
+
+            {
+                DataKeys.INPUT: ...,
+                DataKeys.TARGET: ...,
+                DataKeys.METADATA: ...,
+            }
 
         You would need to use :class:`flash.core.data.transforms.ApplyToKeys` as follows:
 
-        Example:
+        .. code-block:: python
 
             from flash.core.data.transforms import ApplyToKeys
 
-        class MyInputTransform(InputTransform):
+            class MyInputTransform(InputTransform):
 
                 def per_batch_transform(self) -> Callable:
                     return ApplyToKeys("input", my_func)
@@ -616,21 +630,22 @@ def predict_target_per_batch_transform(self) -> Callable:
 
     def serve_per_batch_transform(self) -> Callable:
         """Defines the transform to be applied on a batch of data on cpu for the serving stage.
 
-        The input data of the transform would have the following form:
-        {
-            DataKeys.INPUT: ...,
-            DataKeys.TARGET: ...,
-            DataKeys.METADATA: ...,
-        }
+        The input data of the transform would have the following form::
+
+            {
+                DataKeys.INPUT: ...,
+                DataKeys.TARGET: ...,
+                DataKeys.METADATA: ...,
+            }
 
         You would need to use :class:`flash.core.data.transforms.ApplyToKeys` as follows:
 
-        Example:
+        .. code-block:: python
 
             from flash.core.data.transforms import ApplyToKeys
 
-        class MyInputTransform(InputTransform):
+            class MyInputTransform(InputTransform):
 
                 def per_batch_transform(self) -> Callable:
                     return ApplyToKeys("input", my_func)
@@ -654,21 +669,22 @@ def serve_target_per_batch_transform(self) -> Callable:
 
     def per_batch_transform_on_device(self) -> Callable:
         """Defines the transform to be applied on a batch of data on device for all stages stage.
 
-        The input data of the transform would have the following form:
-        {
-            DataKeys.INPUT: ...,
-            DataKeys.TARGET: ...,
-            DataKeys.METADATA: ...,
-        }
+        The input data of the transform would have the following form::
+
+            {
+                DataKeys.INPUT: ...,
+                DataKeys.TARGET: ...,
+                DataKeys.METADATA: ...,
+            }
 
         You would need to use :class:`flash.core.data.transforms.ApplyToKeys` as follows:
 
-        Example:
+        .. code-block:: python
 
             from flash.core.data.transforms import ApplyToKeys
 
-        class MyInputTransform(InputTransform):
+            class MyInputTransform(InputTransform):
 
                 def per_batch_transform_on_device(self) -> Callable:
                     return ApplyToKeys("input", my_func)
@@ -688,12 +704,13 @@ def target_per_batch_transform_on_device(self) -> Callable:
 
     def train_per_batch_transform_on_device(self) -> Callable:
         """Defines the transform to be applied on a batch of data on device for the training stage.
 
-        The input data of the transform would have the following form:
-        {
-            DataKeys.INPUT: ...,
-            DataKeys.TARGET: ...,
-            DataKeys.METADATA: ...,
-        }
+        The input data of the transform would have the following form::
+
+            {
+                DataKeys.INPUT: ...,
+                DataKeys.TARGET: ...,
+                DataKeys.METADATA: ...,
+            }
         """
         return self._identity
@@ -710,21 +727,22 @@ def train_target_per_batch_transform_on_device(self) -> Callable:
 
     def val_per_batch_transform_on_device(self) -> Callable:
         """Defines the transform to be applied on a batch of data on device for the validating stage.
 
-        The input data of the transform would have the following form:
-        {
-            DataKeys.INPUT: ...,
-            DataKeys.TARGET: ...,
-            DataKeys.METADATA: ...,
-        }
+        The input data of the transform would have the following form::
+
+            {
+                DataKeys.INPUT: ...,
+                DataKeys.TARGET: ...,
+                DataKeys.METADATA: ...,
+            }
 
         You would need to use :class:`flash.core.data.transforms.ApplyToKeys` as follows:
 
-        Example:
+        ..
code-block:: python from flash.core.data.transforms import ApplyToKeys - class MyInputTransform(InputTransform): + class MyInputTransform(InputTransform): def per_batch_transform_on_device(self) -> Callable: return ApplyToKeys("input", my_func) @@ -744,12 +762,13 @@ def val_target_per_batch_transform_on_device(self) -> Callable: def test_per_batch_transform_on_device(self) -> Callable: """Defines the transform to be applied on a batch of data on device for the testing stage. - The input data of the transform would have the following form: - { - DataKeys.INPUT: ..., - DataKeys.TARGET: ..., - DataKeys.METADATA: ..., - } + The input data of the transform would have the following form:: + + { + DataKeys.INPUT: ..., + DataKeys.TARGET: ..., + DataKeys.METADATA: ..., + } """ return self._identity @@ -766,21 +785,22 @@ def test_target_per_batch_transform_on_device(self) -> Callable: def predict_per_batch_transform_on_device(self) -> Callable: """Defines the transform to be applied on a batch of data on device for the predicting stage. - The input data of the transform would have the following form: - { - DataKeys.INPUT: ..., - DataKeys.TARGET: ..., - DataKeys.METADATA: ..., - } + The input data of the transform would have the following form:: + + { + DataKeys.INPUT: ..., + DataKeys.TARGET: ..., + DataKeys.METADATA: ..., + } You would need to use :class:`flash.core.data.transforms.ApplyToKeys` as follows: - Example: + .. code-block:: python from flash.core.data.transforms import ApplyToKeys - class MyInputTransform(InputTransform): + class MyInputTransform(InputTransform): def per_batch_transform_on_device(self) -> Callable: return ApplyToKeys("input", my_func) @@ -1139,7 +1159,7 @@ def _create_collate_input_transform_processors( worker_input_transform_processor = _InputTransformProcessorV2( input_transform, worker_collate_fn, - input_transform._identity if input_transform.serving else input_transform._per_sample_transform, + input_transform._per_sample_transform, input_transform._per_batch_transform, input_transform.running_stage, callbacks=callbacks, diff --git a/flash/core/data/io/input.py b/flash/core/data/io/input.py index 2b15f99620..1398e8fce8 100644 --- a/flash/core/data/io/input.py +++ b/flash/core/data/io/input.py @@ -332,12 +332,19 @@ def __next__(self) -> Any: class ServeInput(Input): def __init__( self, + transform: INPUT_TRANSFORM_TYPE = None, + transform_kwargs: Optional[Dict] = None, data_pipeline_state: Optional["flash.core.data.data_pipeline.DataPipelineState"] = None, ) -> None: if hasattr(self, "serve_load_data"): raise MisconfigurationException("`serve_load_data` shouldn't be implemented.") - super().__init__(RunningStage.SERVING, data_pipeline_state=data_pipeline_state) + super().__init__( + RunningStage.SERVING, + transform=transform, + transform_kwargs=transform_kwargs, + data_pipeline_state=data_pipeline_state, + ) def serve_load_sample(self, sample: Any) -> List[Any]: raise NotImplementedError diff --git a/flash/core/data/io/input_transform.py b/flash/core/data/io/input_transform.py index d8aa6458a5..33a8a0bdb9 100644 --- a/flash/core/data/io/input_transform.py +++ b/flash/core/data/io/input_transform.py @@ -11,600 +11,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
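Taken together, the hunks above document the new hook scheme: every transform hook has a generic form plus stage-prefixed variants (``train_``, ``val_``, ``test_``, ``predict_``, ``serve_``), and the stage-specific override is used for that stage. A minimal sketch of the pattern the docstrings describe, assuming torchvision is available (``my_func`` in the docstrings is a placeholder; ``T.ToTensor`` and ``T.RandomHorizontalFlip`` here are illustrative stand-ins):

.. code-block:: python

    from typing import Callable

    from torchvision import transforms as T

    from flash.core.data.input_transform import InputTransform
    from flash.core.data.transforms import ApplyToKeys


    class MyInputTransform(InputTransform):
        # Generic hook: runs on every sample, on CPU, for all stages.
        def per_sample_transform(self) -> Callable:
            return ApplyToKeys("input", T.ToTensor())

        # Stage-prefixed hook: used instead of the generic one while training.
        def train_per_sample_transform(self) -> Callable:
            return ApplyToKeys("input", T.Compose([T.RandomHorizontalFlip(), T.ToTensor()]))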
-import inspect -from abc import ABC, abstractclassmethod, abstractmethod -from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union +from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple import torch -from pytorch_lightning.utilities.exceptions import MisconfigurationException -from torch.utils.data._utils.collate import default_collate +import flash from flash.core.data.callback import ControlFlow, FlashCallback -from flash.core.data.io.input import DataKeys, Input -from flash.core.data.process import Deserializer -from flash.core.data.properties import ProcessState, Properties -from flash.core.data.states import ( - CollateFn, - PerBatchTransform, - PerBatchTransformOnDevice, - PerSampleTransform, - PerSampleTransformOnDevice, -) -from flash.core.data.transforms import ApplyToKeys -from flash.core.data.utils import ( - _INPUT_TRANSFORM_FUNCS, - _STAGES_PREFIX, - convert_to_modules, - CurrentFuncContext, - CurrentRunningStageContext, - CurrentRunningStageFuncContext, -) +from flash.core.data.io.input import DataKeys +from flash.core.data.utils import convert_to_modules from flash.core.utilities.stages import RunningStage -class BaseInputTransform(ABC): - @abstractmethod - def get_state_dict(self) -> Dict[str, Any]: - """Override this method to return state_dict.""" - - @abstractclassmethod - def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool = False): - """Override this method to load from state_dict.""" - - -class InputTransform(BaseInputTransform, Properties): - """The :class:`~flash.core.data.io.input_transform.InputTransform` encapsulates all the data processing logic - that should run before the data is passed to the model. It is particularly useful when you want to provide an - end to end implementation which works with 4 different stages: ``train``, ``validation``, ``test``, and - inference (``predict``). - - The :class:`~flash.core.data.io.input_transform.InputTransform` supports the following hooks: - - - ``per_sample_transform``: Performs transforms on a single data sample. - Example:: - - * Input: Receive a PIL Image and its label. - - * Action: Rotate the PIL Image and Convert the rotated PIL Image to a tensor. - - * Output: Return the tensored image and its label. - - - ``per_batch_transform``: Performs transforms on a batch. - In this example, we decided not to override the hook. - - - ``per_sample_transform_on_device``: Performs transform on a sample already on a ``GPU`` or ``TPU``. - Example:: - - * Input: Receive a tensored image on device and its label. - - * Action: Apply random transforms. - - * Output: Return an augmented tensored image on device and its label. - - - ``collate``: Converts a sequence of data samples into a batch. - Defaults to ``torch.utils.data._utils.collate.default_collate``. - Example:: - - * Input: Receive a list of augmented tensored images and their respective labels. - - * Action: Collate the list of images into batch. - - * Output: Return a batch of images and their labels. - - - ``per_batch_transform_on_device``: Performs transform on a batch already on ``GPU`` or ``TPU``. - Example:: - - * Input: Receive a batch of images and their labels. - - * Action: Apply normalization on the batch by subtracting the mean - and dividing by the standard deviation from ImageNet. - - * Output: Return a normalized augmented batch of images and their labels. - - .. 
note:: - - The ``per_sample_transform_on_device`` and ``per_batch_transform`` are mutually exclusive - as using both would impact performance. - - Data processing can be configured by overriding hooks or through transforms. The input transforms are given as - a mapping from hook names to callables. Default transforms can be configured by overriding the - ``default_transforms`` or ``{train,val,test,predict}_default_transforms`` methods. These can then be overridden by - the user with the ``{train,val,test,predict}_transform`` arguments to the ``InputTransform``. - All of the hooks can be used in the transform mappings. - - Example:: - - class CustomInputTransform(InputTransform): - - def default_transforms() -> Mapping[str, Callable]: - return { - "per_sample_transform": transforms.ToTensor(), - "collate": torch.utils.data._utils.collate.default_collate, - } - - def train_default_transforms() -> Mapping[str, Callable]: - return { - "per_sample_transform": T.Compose([transforms.RandomHorizontalFlip(), transforms.ToTensor()]), - "collate": torch.utils.data._utils.collate.default_collate, - } - - When overriding hooks for particular stages, you can prefix with ``train``, ``val``, ``test`` or ``predict``. For - example, you can achieve the same as the above example by implementing ``train_per_sample_transform``. - - Example:: - - class CustomInputTransform(InputTransform): - - def train_per_sample_transform(self, sample: PIL.Image) -> PIL.Image: - return transforms.RandomHorizontalFlip()(sample) - - def collate(self, samples: List[torch.Tensor]) -> torch.Tensor: - return torch.utils.data._utils.collate.default_collate(samples) - - Each hook is aware of the Trainer running stage through booleans. These are useful for adapting functionality for a - stage without duplicating code. - - Example:: - - class CustomInputTransform(InputTransform): - - def per_sample_transform(self, sample: PIL.Image) -> PIL.Image: - - if self.training: - # logic for training - - elif self.validating: - # logic for validation - - elif self.testing: - # logic for testing - - elif self.predicting: - # logic for predicting - """ - - def __init__( - self, - train_transform: Optional[Union[Callable, List, Dict[str, Callable]]] = None, - val_transform: Optional[Union[Callable, List, Dict[str, Callable]]] = None, - test_transform: Optional[Union[Callable, List, Dict[str, Callable]]] = None, - predict_transform: Optional[Union[Callable, List, Dict[str, Callable]]] = None, - inputs: Optional[Dict[str, "Input"]] = None, - deserializer: Optional["Deserializer"] = None, - default_input: Optional[str] = None, - ): - super().__init__() - - # resolve the default transforms - train_transform = train_transform or self._resolve_transforms(RunningStage.TRAINING) - val_transform = val_transform or self._resolve_transforms(RunningStage.VALIDATING) - test_transform = test_transform or self._resolve_transforms(RunningStage.TESTING) - predict_transform = predict_transform or self._resolve_transforms(RunningStage.PREDICTING) - - # used to keep track of provided transforms - self._train_collate_in_worker_from_transform: Optional[bool] = None - self._val_collate_in_worker_from_transform: Optional[bool] = None - self._predict_collate_in_worker_from_transform: Optional[bool] = None - self._test_collate_in_worker_from_transform: Optional[bool] = None - - # store the transform before conversion to modules.
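To make the note above concrete: defining both hooks for the same stage is rejected by ``_check_transforms`` (below), since if both were specified, uncollation would have to be applied. A hedged sketch (``my_batch_fn`` and ``my_sample_fn`` are hypothetical callables):

.. code-block:: python

    # Raises MisconfigurationException: `per_batch_transform` runs in the worker
    # before transfer to device, `per_sample_transform_on_device` runs after
    # transfer, and the two cannot be combined for a single stage.
    InputTransform(
        train_transform={
            "per_batch_transform": my_batch_fn,
            "per_sample_transform_on_device": my_sample_fn,
        }
    )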
- self.train_transform = self._check_transforms(train_transform, RunningStage.TRAINING) - self.val_transform = self._check_transforms(val_transform, RunningStage.VALIDATING) - self.test_transform = self._check_transforms(test_transform, RunningStage.TESTING) - self.predict_transform = self._check_transforms(predict_transform, RunningStage.PREDICTING) - - self._train_transform = convert_to_modules(self.train_transform) - self._val_transform = convert_to_modules(self.val_transform) - self._test_transform = convert_to_modules(self.test_transform) - self._predict_transform = convert_to_modules(self.predict_transform) - - self._inputs = inputs - self._deserializer = deserializer - self._default_input = default_input - self._callbacks: List[FlashCallback] = [] - self._default_collate: Callable = default_collate - - @property - def deserializer(self) -> Optional["Deserializer"]: - return self._deserializer - - def _resolve_transforms(self, running_stage: RunningStage) -> Optional[Dict[str, Callable]]: - from flash.core.data.data_pipeline import DataPipeline - - resolved_function = getattr( - self, DataPipeline._resolve_function_hierarchy("default_transforms", self, running_stage, InputTransform) - ) - - with CurrentRunningStageFuncContext(running_stage, "default_transforms", self): - transforms: Optional[Dict[str, Callable]] = resolved_function() - return transforms - - def _save_to_state_dict(self, destination, prefix, keep_vars): - input_transform_state_dict = self.get_state_dict() - if not isinstance(input_transform_state_dict, Dict): - raise MisconfigurationException("get_state_dict should return a dictionary") - input_transform_state_dict["_meta"] = {} - input_transform_state_dict["_meta"]["module"] = self.__module__ - input_transform_state_dict["_meta"]["class_name"] = self.__class__.__name__ - input_transform_state_dict["_meta"]["_state"] = self._state - destination["input_transform.state_dict"] = input_transform_state_dict - self._ddp_params_and_buffers_to_ignore = ["input_transform.state_dict"] - return super()._save_to_state_dict(destination, prefix, keep_vars) - - def _check_transforms( - self, transform: Optional[Dict[str, Callable]], stage: RunningStage - ) -> Optional[Dict[str, Callable]]: - if transform is None: - return transform - - if isinstance(transform, list): - transform = {"per_sample_transform": ApplyToKeys(DataKeys.INPUT, torch.nn.Sequential(*transform))} - elif callable(transform): - transform = {"per_sample_transform": ApplyToKeys(DataKeys.INPUT, transform)} - - if not isinstance(transform, Dict): - raise MisconfigurationException( - "Transform should be a dict. " - f"Here are the available keys for your transforms: {_INPUT_TRANSFORM_FUNCS}." - ) - - keys_diff = set(transform.keys()).difference(_INPUT_TRANSFORM_FUNCS) - - if len(keys_diff) > 0: - raise MisconfigurationException( - f"{stage}_transform contains {keys_diff}. Only {_INPUT_TRANSFORM_FUNCS} keys are supported." - ) - - is_per_batch_transform_in = "per_batch_transform" in transform - is_per_sample_transform_on_device_in = "per_sample_transform_on_device" in transform - - if is_per_batch_transform_in and is_per_sample_transform_on_device_in: - raise MisconfigurationException( - f"{transform}: `per_batch_transform` and `per_sample_transform_on_device` are mutually exclusive." 
- ) - - collate_in_worker: Optional[bool] = None - - if is_per_batch_transform_in or (not is_per_batch_transform_in and not is_per_sample_transform_on_device_in): - collate_in_worker = True - - elif is_per_sample_transform_on_device_in: - collate_in_worker = False - - setattr(self, f"_{_STAGES_PREFIX[stage]}_collate_in_worker_from_transform", collate_in_worker) - return transform - - @staticmethod - def _identity(x: Any) -> Any: - return x - - def _get_transform(self, transform: Dict[str, Callable]) -> Callable: - if self.current_fn in transform: - return transform[self.current_fn] - return self._identity - - @property - def current_transform(self) -> Callable: - if self.training and self._train_transform: - return self._get_transform(self._train_transform) - if self.validating and self._val_transform: - return self._get_transform(self._val_transform) - if self.testing and self._test_transform: - return self._get_transform(self._test_transform) - if self.predicting and self._predict_transform: - return self._get_transform(self._predict_transform) - return self._identity - - @property - def transforms(self) -> Dict[str, Optional[Dict[str, Callable]]]: - """The transforms currently being used by this - :class:`~flash.core.data.io.input_transform.InputTransform`.""" - return { - "train_transform": self.train_transform, - "val_transform": self.val_transform, - "test_transform": self.test_transform, - "predict_transform": self.predict_transform, - } - - @property - def callbacks(self) -> List["FlashCallback"]: - if not hasattr(self, "_callbacks"): - self._callbacks: List[FlashCallback] = [] - return self._callbacks - - @callbacks.setter - def callbacks(self, callbacks: List["FlashCallback"]): - self._callbacks = callbacks - - def add_callbacks(self, callbacks: List["FlashCallback"]): - _callbacks = [c for c in callbacks if c not in self._callbacks] - self._callbacks.extend(_callbacks) - - @staticmethod - def default_transforms() -> Optional[Dict[str, Callable]]: - """The default transforms to use. - - Will be overridden by transforms passed to the ``__init__``. 
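Note also the shorthand handling in ``_check_transforms`` above: a bare callable (or a list of callables) is accepted and silently normalised into a ``per_sample_transform`` entry applied to ``DataKeys.INPUT``. Roughly, with ``my_func`` a hypothetical callable:

.. code-block:: python

    from flash.core.data.io.input import DataKeys
    from flash.core.data.transforms import ApplyToKeys

    # Passing a bare callable ...
    InputTransform(train_transform=my_func)
    # ... is treated as the explicit mapping:
    InputTransform(train_transform={"per_sample_transform": ApplyToKeys(DataKeys.INPUT, my_func)})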
- """ - - def _apply_sample_transform(self, sample: Any) -> Any: - if isinstance(sample, list): - return [self.current_transform(s) for s in sample] - return self.current_transform(sample) - - def _apply_batch_transform(self, batch: Any): - return self.current_transform(batch) - - def _apply_transform_on_sample(self, sample: Any, transform: Callable): - if isinstance(sample, list): - return [transform(s) for s in sample] - - return transform(sample) - - def _apply_transform_on_batch(self, batch: Any, transform: Callable): - return transform(batch) - - def _apply_process_state_transform( - self, - process_state: ProcessState, - sample: Optional[Any] = None, - batch: Optional[Any] = None, - ): - # assert both sample and batch are not None - if sample is None: - assert batch is not None, "sample not provided, batch should not be None" - mode = "batch" - else: - assert batch is None, "sample provided, batch should be None" - mode = "sample" - - process_state_transform = self.get_state(process_state) - - if process_state_transform is not None: - if process_state_transform.transform is not None: - if mode == "sample": - return self._apply_transform_on_sample(sample, process_state_transform.transform) - else: - return self._apply_transform_on_batch(batch, process_state_transform.transform) - else: - if mode == "sample": - return sample - else: - return batch - else: - if mode == "sample": - return self._apply_sample_transform(sample) - else: - return self._apply_batch_transform(batch) - - def per_sample_transform(self, sample: Any) -> Any: - """Transforms to apply on a single object.""" - return self._apply_process_state_transform(PerSampleTransform, sample=sample) - - def per_batch_transform(self, batch: Any) -> Any: - """Transforms to apply to a whole batch (if possible use this for efficiency). - - .. note:: - - This option is mutually exclusive with :meth:`per_sample_transform_on_device`, - since if both are specified, uncollation has to be applied. - """ - return self._apply_process_state_transform(PerBatchTransform, batch=batch) - - def collate(self, samples: Sequence, metadata=None) -> Any: - """Transform to convert a sequence of samples to a collated batch.""" - current_transform = self.current_transform - if current_transform is self._identity: - current_transform = self._default_collate - - # the model can provide a custom ``collate_fn``. - collate_fn = self.get_state(CollateFn) - if collate_fn is not None: - collate_fn = collate_fn.collate_fn - else: - collate_fn = current_transform - # return collate_fn.collate_fn(samples) - - parameters = inspect.signature(collate_fn).parameters - if len(parameters) > 1 and DataKeys.METADATA in parameters: - return collate_fn(samples, metadata) - return collate_fn(samples) - - def per_sample_transform_on_device(self, sample: Any) -> Any: - """Transforms to apply to the data before the collation (per-sample basis). - - .. note:: - - This option is mutually exclusive with :meth:`per_batch_transform`, - since if both are specified, uncollation has to be applied. - - .. note:: - - This function won't be called within the dataloader workers, since to make that happen - each of the workers would have to create it's own CUDA-context which would pollute GPU memory (if on GPU). - """ - return self._apply_process_state_transform(PerSampleTransformOnDevice, sample=sample) - - def per_batch_transform_on_device(self, batch: Any) -> Any: - """Transforms to apply to a whole batch (if possible use this for efficiency). - - .. 
note:: - - This function won't be called within the dataloader workers, since to make that happen - each of the workers would have to create its own CUDA-context which would pollute GPU memory (if on GPU). - """ - return self._apply_process_state_transform(PerBatchTransformOnDevice, batch=batch) - - def available_inputs(self) -> Sequence[str]: - """Get the list of available data source names for use with this - :class:`~flash.core.data.io.input_transform.InputTransform`. - - Returns: - The list of data source names. - """ - return list(self._inputs.keys()) - - def input_of_name(self, input_name: str) -> Input: - """Get the :class:`~flash.core.data.io.input.Input` of the given name from the - :class:`~flash.core.data.io.input_transform.InputTransform`. - - Args: - input_name: The name of the data source to look up. - - Returns: - The :class:`~flash.core.data.io.input.Input` of the given name. - - Raises: - MisconfigurationException: If the requested data source is not configured by this - :class:`~flash.core.data.io.input_transform.InputTransform`. - """ - if input_name == "default": - input_name = self._default_input - inputs = self._inputs - if input_name in inputs: - return inputs[input_name] - raise MisconfigurationException( - f"No '{input_name}' data source is available for use with the {type(self)}. The available data " - f"sources are: {', '.join(self.available_inputs())}." - ) - - -class DefaultInputTransform(InputTransform): - def __init__( - self, - train_transform: Optional[Union[Callable, List, Dict[str, Callable]]] = None, - val_transform: Optional[Union[Callable, List, Dict[str, Callable]]] = None, - test_transform: Optional[Union[Callable, List, Dict[str, Callable]]] = None, - predict_transform: Optional[Union[Callable, List, Dict[str, Callable]]] = None, - inputs: Optional[Dict[str, "Input"]] = None, - default_input: Optional[str] = None, - ): - super().__init__( - train_transform=train_transform, - val_transform=val_transform, - test_transform=test_transform, - predict_transform=predict_transform, - inputs=inputs or {"default": Input}, - default_input=default_input or "default", - ) - - def get_state_dict(self) -> Dict[str, Any]: - return {**self.transforms} - - @classmethod - def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool): - return cls(**state_dict) - - -class _InputTransformProcessor(torch.nn.Module): - """ - This class is used to encapsulate the following functions of an InputTransform object: - Inside a worker: - per_sample_transform: Function to transform an individual sample - collate: Function to merge samples into a batch - per_batch_transform: Function to transform an individual batch - - Inside main process: - per_sample_transform_on_device: Function to transform an individual sample - collate: Function to merge samples into a batch - per_batch_transform_on_device: Function to transform an individual batch - """ - - def __init__( - self, - input_transform: InputTransform, - collate_fn: Callable, - per_sample_transform: Callable, - per_batch_transform: Callable, - stage: RunningStage, - apply_per_sample_transform: bool = True, - on_device: bool = False, - callbacks: Optional[List[FlashCallback]] = None, - ): - super().__init__() - self.input_transform = input_transform - self.callback = ControlFlow(callbacks or []) - self.collate_fn = convert_to_modules(collate_fn) - self.per_sample_transform = convert_to_modules(per_sample_transform) - self.per_batch_transform = convert_to_modules(per_batch_transform) -
self.apply_per_sample_transform = apply_per_sample_transform - self.stage = stage - self.on_device = on_device - - extension = f"{'_on_device' if self.on_device else ''}" - self._current_stage_context = CurrentRunningStageContext(stage, input_transform) - self._per_sample_transform_context = CurrentFuncContext(f"per_sample_transform{extension}", input_transform) - self._collate_context = CurrentFuncContext("collate", input_transform) - self._per_batch_transform_context = CurrentFuncContext(f"per_batch_transform{extension}", input_transform) - - @staticmethod - def _extract_metadata( - samples: List[Dict[str, Any]], - ) -> Tuple[List[Dict[str, Any]], Optional[List[Dict[str, Any]]]]: - metadata = [s.pop(DataKeys.METADATA, None) if isinstance(s, Mapping) else None for s in samples] - return samples, metadata if any(m is not None for m in metadata) else None - - def forward(self, samples: Sequence[Any]) -> Any: - if not self.on_device: - for sample in samples: - self.callback.on_load_sample(sample, self.stage) - - # we create a new dict to prevent potential memory leaks - # assuming that the dictionary samples are stored in between and - # potentially modified before the transforms are applied. - if isinstance(samples, dict): - samples = dict(samples.items()) - - with self._current_stage_context: - - if self.apply_per_sample_transform: - with self._per_sample_transform_context: - _samples = [] - - if isinstance(samples, Mapping): - samples = [samples] - - for sample in samples: - sample = self.per_sample_transform(sample) - if self.on_device: - self.callback.on_per_sample_transform_on_device(sample, self.stage) - else: - self.callback.on_per_sample_transform(sample, self.stage) - _samples.append(sample) - - samples = type(_samples)(_samples) - - with self._collate_context: - samples, metadata = self._extract_metadata(samples) - try: - samples = self.collate_fn(samples, metadata) - except TypeError: - samples = self.collate_fn(samples) - if metadata and isinstance(samples, dict): - samples[DataKeys.METADATA] = metadata - self.callback.on_collate(samples, self.stage) - - with self._per_batch_transform_context: - samples = self.per_batch_transform(samples) - if self.on_device: - self.callback.on_per_batch_transform_on_device(samples, self.stage) - else: - self.callback.on_per_batch_transform(samples, self.stage) - return samples - - def __str__(self) -> str: - # todo: define repr function which would take object and string attributes to be shown - return ( - "_InputTransformProcessor:\n" - f"\t(per_sample_transform): {str(self.per_sample_transform)}\n" - f"\t(collate_fn): {str(self.collate_fn)}\n" - f"\t(per_batch_transform): {str(self.per_batch_transform)}\n" - f"\t(apply_per_sample_transform): {str(self.apply_per_sample_transform)}\n" - f"\t(on_device): {str(self.on_device)}\n" - f"\t(stage): {str(self.stage)}" - ) - - class _InputTransformProcessorV2(torch.nn.Module): """ This class is used to encapsulate the following functions of an InputTransform object: @@ -621,7 +38,7 @@ class _InputTransformProcessorV2(torch.nn.Module): def __init__( self, - input_transform: InputTransform, + input_transform: "flash.core.data.input_transform.InputTransform", collate_fn: Callable, per_sample_transform: Callable, per_batch_transform: Callable, diff --git a/flash/core/data/new_data_module.py b/flash/core/data/new_data_module.py deleted file mode 100644 index 7d8e1864c7..0000000000 --- a/flash/core/data/new_data_module.py +++ /dev/null @@ -1,353 +0,0 @@ -# Copyright The PyTorch Lightning
team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from typing import Any, Callable, Mapping, Optional, Type - -import pytorch_lightning as pl -import torch -from pytorch_lightning import LightningDataModule -from pytorch_lightning.utilities.exceptions import MisconfigurationException -from torch.utils.data import DataLoader -from torch.utils.data._utils.collate import default_collate -from torch.utils.data.dataset import IterableDataset -from torch.utils.data.sampler import Sampler - -import flash -from flash.core.data.base_viz import BaseVisualization -from flash.core.data.callback import BaseDataFetcher -from flash.core.data.data_module import DataModule -from flash.core.data.data_pipeline import DataPipelineState -from flash.core.data.input_transform import InputTransform -from flash.core.data.io.input import DataKeys, Input -from flash.core.data.io.input_transform import DefaultInputTransform -from flash.core.data.io.output_transform import OutputTransform -from flash.core.registry import FlashRegistry - - -class DatasetInput(Input): - """The ``DatasetInput`` implements default behaviours for data sources which expect the input to - :meth:`~flash.core.data.io.input.Input.load_data` to be a :class:`torch.utils.data.dataset.Dataset` - - Args: - labels: Optionally pass the labels as a mapping from class index to label string. These will then be set as the - :class:`~flash.core.data.io.input.ClassificationState`. - """ - - def load_sample(self, sample: Any) -> Mapping[str, Any]: - if isinstance(sample, tuple) and len(sample) == 2: - return {DataKeys.INPUT: sample[0], DataKeys.TARGET: sample[1]} - return {DataKeys.INPUT: sample} - - -class DataModule(DataModule): - """A basic DataModule class for all Flash tasks. This class includes references to a - :class:`~flash.core.data.datasets.Input` and a :class:`~flash.core.data.callback.BaseDataFetcher`. - - Args: - train_input: Input dataset for training. Defaults to None. - val_input: Input dataset for validating model performance during training. Defaults to None. - test_input: Input dataset to test model performance. Defaults to None. - predict_input: Input dataset for predicting. Defaults to None. - data_fetcher: The :class:`~flash.core.data.callback.BaseDataFetcher` to attach to the - :class:`~flash.core.data.io.input_transform.InputTransform`. If ``None``, the output from - :meth:`~flash.core.data.data_module.DataModule.configure_data_fetcher` will be used. - val_split: An optional float which gives the relative amount of the training dataset to use for the validation - dataset. - batch_size: The batch size to be used by the DataLoader. Defaults to 1. - num_workers: The number of workers to use for parallelized loading. - Defaults to None which equals the number of available CPU threads, - or 0 for Windows or Darwin platform. - sampler: A sampler following the :class:`~torch.utils.data.sampler.Sampler` type. - Will be passed to the DataLoader for the training dataset. Defaults to None. 
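For reference, a hedged sketch of constructing this ``DataModule`` directly from ``Input`` objects (``MyInput`` and the data arguments are hypothetical; as enforced in ``__init__`` below, ``batch_size`` must be provided):

.. code-block:: python

    from flash.core.data.data_module import DataModule
    from flash.core.utilities.stages import RunningStage

    datamodule = DataModule(
        train_input=MyInput(RunningStage.TRAINING, train_data),
        val_input=MyInput(RunningStage.VALIDATING, val_data),
        batch_size=8,
        num_workers=2,
    )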
- """ - - input_transform_cls = DefaultInputTransform - output_transform_cls = OutputTransform - input_transforms_registry: Optional[FlashRegistry] = None - - def __init__( - self, - train_input: Optional[Input] = None, - val_input: Optional[Input] = None, - test_input: Optional[Input] = None, - predict_input: Optional[Input] = None, - data_fetcher: Optional[BaseDataFetcher] = None, - val_split: Optional[float] = None, - batch_size: Optional[int] = None, - num_workers: int = 0, - sampler: Optional[Type[Sampler]] = None, - pin_memory: bool = True, - persistent_workers: bool = True, - output_transform: Optional[OutputTransform] = None, - ) -> None: - - if not batch_size: - raise MisconfigurationException("The `batch_size` should be provided to the DataModule on instantiation.") - - if flash._IS_TESTING and torch.cuda.is_available(): - batch_size = 16 - - self._input_transform: Optional[OutputTransform] = None - self._output_transform: Optional[OutputTransform] = output_transform - self._viz: Optional[BaseVisualization] = None - - self._train_input = train_input - self._val_input = val_input - self._test_input = test_input - self._predict_input = predict_input - - self._data_fetcher: Optional[BaseDataFetcher] = data_fetcher or self.configure_data_fetcher() - - self._train_dataloader_collate_fn = self._resolve_dataloader_collate_fn(self._train_input) - self._val_dataloader_collate_fn = self._resolve_dataloader_collate_fn(self._val_input) - self._test_dataloader_collate_fn = self._resolve_dataloader_collate_fn(self._test_input) - self._predict_dataloader_collate_fn = self._resolve_dataloader_collate_fn(self._predict_input) - - self._train_on_after_batch_transfer_fn = self._resolve_on_after_batch_transfer_fn(self._train_input) - self._val_on_after_batch_transfer_fn = self._resolve_on_after_batch_transfer_fn(self._val_input) - self._test_on_after_batch_transfer_fn = self._resolve_on_after_batch_transfer_fn(self._test_input) - self._predict_on_after_batch_transfer_fn = self._resolve_on_after_batch_transfer_fn(self._predict_input) - - if self._train_input and self._val_input and isinstance(val_split, float) and val_split > 0: - raise MisconfigurationException( - "A `val_dataset` was provided with `val_split`. Please, choose one or the other." - ) - - if self._train_input is not None and (val_split is not None and self._val_input is None): - self._train_input, self._val_input = self._split_train_val(self._train_input, val_split) - - if self._train_input: - self.train_dataloader = self._train_dataloader - - if self._val_input: - self.val_dataloader = self._val_dataloader - - if self._test_input: - self.test_dataloader = self._test_dataloader - - if self._predict_input: - self.predict_dataloader = self._predict_dataloader - - self.batch_size = batch_size - - if num_workers is None: - num_workers = 0 - self.num_workers = num_workers - self.persistent_workers = persistent_workers and num_workers > 0 - self.pin_memory = pin_memory - - self.sampler = sampler - - LightningDataModule.__init__(self) - - @property - def input_transform(self) -> InputTransform: - """Property that returns the input transform class used on input data.""" - # Find a better way to resolve this. 
- return self._train_input.transform or self.input_transform_cls() - - @property - def train_dataset(self) -> Optional[Input]: - """This property returns the train dataset.""" - return self._train_input - - @property - def val_dataset(self) -> Optional[Input]: - """This property returns the validation dataset.""" - return self._val_input - - @property - def test_dataset(self) -> Optional[Input]: - """This property returns the test dataset.""" - return self._test_input - - @property - def predict_dataset(self) -> Optional[Input]: - """This property returns the predict dataset.""" - return self._predict_input - - def _resolve_transform(self, ds: Optional[Input]) -> Optional[InputTransform]: - if not isinstance(ds, Input): - return None - return ds.transform - - def _resolve_dataloader_collate_fn(self, ds: Optional[Input]) -> Optional[Callable]: - if not ds: - return None - if isinstance(ds.transform, InputTransform): - return ds._create_dataloader_collate_fn([self.data_fetcher]) - return default_collate - - def _resolve_on_after_batch_transfer_fn(self, ds: Optional[Input]) -> Optional[Callable]: - if not ds: - return None - if isinstance(ds.transform, InputTransform): - return ds._create_on_after_batch_transfer_fn([self.data_fetcher]) - - def _train_dataloader(self) -> DataLoader: - train_ds: Input = self._train_input - collate_fn = self._train_dataloader_collate_fn - shuffle: bool = False - if isinstance(train_ds, IterableDataset): - drop_last = False - else: - drop_last = len(train_ds) > self.batch_size - - if self.sampler is None: - sampler = None - shuffle = not isinstance(train_ds, IterableDataset) - else: - sampler = self.sampler(train_ds) - - if isinstance(getattr(self, "trainer", None), pl.Trainer): - if isinstance(self.trainer.lightning_module, flash.Task): - self.connect(self.trainer.lightning_module) - return self.trainer.lightning_module.process_train_dataset( - train_ds, - trainer=self.trainer, - batch_size=self.batch_size, - num_workers=self.num_workers, - pin_memory=self.pin_memory, - shuffle=shuffle, - drop_last=drop_last, - collate_fn=collate_fn, - sampler=sampler, - ) - - return DataLoader( - train_ds, - batch_size=self.batch_size, - shuffle=shuffle, - sampler=sampler, - num_workers=self.num_workers, - pin_memory=self.pin_memory, - drop_last=drop_last, - collate_fn=collate_fn, - persistent_workers=self.persistent_workers, - ) - - def _val_dataloader(self) -> DataLoader: - val_ds: Input = self._val_input - collate_fn = self._val_dataloader_collate_fn - - if isinstance(getattr(self, "trainer", None), pl.Trainer): - if isinstance(self.trainer.lightning_module, flash.Task): - self.connect(self.trainer.lightning_module) - return self.trainer.lightning_module.process_val_dataset( - val_ds, - trainer=self.trainer, - batch_size=self.batch_size, - num_workers=self.num_workers, - pin_memory=self.pin_memory, - collate_fn=collate_fn, - ) - - return DataLoader( - val_ds, - batch_size=self.batch_size, - num_workers=self.num_workers, - pin_memory=self.pin_memory, - collate_fn=collate_fn, - persistent_workers=self.persistent_workers, - ) - - def _test_dataloader(self) -> DataLoader: - test_ds: Input = self._test_input - collate_fn = self._test_dataloader_collate_fn - - if isinstance(getattr(self, "trainer", None), pl.Trainer): - if isinstance(self.trainer.lightning_module, flash.Task): - self.connect(self.trainer.lightning_module) - return self.trainer.lightning_module.process_test_dataset( - test_ds, - trainer=self.trainer, - batch_size=self.batch_size, - num_workers=self.num_workers, - 
pin_memory=self.pin_memory, - collate_fn=collate_fn, - ) - - return DataLoader( - test_ds, - batch_size=self.batch_size, - num_workers=self.num_workers, - pin_memory=self.pin_memory, - collate_fn=collate_fn, - persistent_workers=self.persistent_workers, - ) - - def _predict_dataloader(self) -> DataLoader: - predict_ds: Input = self._predict_input - collate_fn = self._predict_dataloader_collate_fn - - if isinstance(predict_ds, IterableDataset): - batch_size = self.batch_size - else: - batch_size = min(self.batch_size, len(predict_ds) if len(predict_ds) > 0 else 1) - - if isinstance(getattr(self, "trainer", None), pl.Trainer): - if isinstance(self.trainer.lightning_module, flash.Task): - self.connect(self.trainer.lightning_module) - return self.trainer.lightning_module.process_predict_dataset( - predict_ds, - batch_size=batch_size, - num_workers=self.num_workers, - pin_memory=self.pin_memory, - collate_fn=collate_fn, - ) - - return DataLoader( - predict_ds, - batch_size=batch_size, - num_workers=self.num_workers, - pin_memory=self.pin_memory, - collate_fn=collate_fn, - persistent_workers=self.persistent_workers, - ) - - def connect(self, task: "flash.Task"): - data_pipeline_state = DataPipelineState() - for properties in [ - self._train_input, - self._val_input, - self._test_input, - self._predict_input, - getattr(self._train_input, "transform", None), - getattr(self._val_input, "transform", None), - getattr(self._test_input, "transform", None), - getattr(self._predict_input, "transform", None), - task._deserializer, - task._output_transform, - task._output, - task, - ]: - if properties is not None: - properties.attach_data_pipeline_state(data_pipeline_state) - - def on_after_batch_transfer(self, batch: Any, dataloader_idx: int) -> Any: - if getattr(self, "trainer", None) is None: - return batch - transform = None - if self.trainer.training: - transform = self._train_on_after_batch_transfer_fn - elif self.trainer.validating or self.trainer.sanity_checking: - transform = self._val_on_after_batch_transfer_fn - elif self.trainer.testing: - transform = self._test_on_after_batch_transfer_fn - elif self.trainer.predicting: - transform = self._predict_on_after_batch_transfer_fn - - if transform: - batch = transform(batch) - - return batch diff --git a/flash/core/data/utils.py b/flash/core/data/utils.py index b502fb3ff2..ec48458d6e 100644 --- a/flash/core/data/utils.py +++ b/flash/core/data/utils.py @@ -42,11 +42,6 @@ } _STAGES_PREFIX_VALUES = {"train", "test", "val", "predict", "serve"} -_INPUT_FUNCS: Set[str] = { - "load_data", - "load_sample", -} - _INPUT_TRANSFORM_FUNCS: Set[str] = { "per_sample_transform", "per_batch_transform", @@ -67,59 +62,6 @@ } -class CurrentRunningStageContext: - def __init__(self, running_stage: RunningStage, obj: Any, reset: bool = True): - self._running_stage = running_stage - self._obj = obj - self._reset = reset - - def __enter__(self): - if self._obj is not None: - if getattr(self._obj, "running_stage", None) != self._running_stage: - self._obj.running_stage = self._running_stage - return self - - def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: - if self._obj is not None and self._reset: - self._obj.running_stage = None - - -class CurrentFuncContext: - def __init__(self, current_fn: str, obj: Any): - self._current_fn = current_fn - self._obj = obj - - def __enter__(self): - if self._obj is not None: - if getattr(self._obj, "current_fn", None) != self._current_fn: - self._obj.current_fn = self._current_fn - return self - - def __exit__(self, 
exc_type: Any, exc_val: Any, exc_tb: Any) -> None: - if self._obj is not None: - self._obj.current_fn = None - - -class CurrentRunningStageFuncContext: - def __init__(self, running_stage: RunningStage, current_fn: str, obj: Any): - self._running_stage = running_stage - self._current_fn = current_fn - self._obj = obj - - def __enter__(self): - if self._obj is not None: - if getattr(self._obj, "running_stage", None) != self._running_stage: - self._obj.running_stage = self._running_stage - if getattr(self._obj, "current_fn", None) != self._current_fn: - self._obj.current_fn = self._current_fn - return self - - def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: - if self._obj is not None: - self._obj.running_stage = None - self._obj.current_fn = None - - def download_data(url: str, path: str = "data/", verbose: bool = False) -> None: """Download file with progressbar. diff --git a/flash/core/model.py b/flash/core/model.py index 2328fb8910..ee6820478e 100644 --- a/flash/core/model.py +++ b/flash/core/model.py @@ -38,11 +38,10 @@ import flash from flash.core.data.data_pipeline import DataPipeline, DataPipelineState -from flash.core.data.io.input import Input, InputBase -from flash.core.data.io.input_transform import InputTransform +from flash.core.data.input_transform import InputTransform +from flash.core.data.io.input import InputBase from flash.core.data.io.output import Output from flash.core.data.io.output_transform import OutputTransform -from flash.core.data.new_data_module import DataModule as NewDataModule from flash.core.data.process import Deserializer, DeserializerMapping from flash.core.data.properties import ProcessState from flash.core.finetuning import _DEFAULTS_FINETUNE_STRATEGIES, _FINETUNING_STRATEGIES_REGISTRY @@ -743,12 +742,6 @@ def build_data_pipeline( input = input or old_input - if isinstance(input, str): - if input_transform is None: - input = Input() # TODO: warn the user that we are not using the specified data source - else: - input = input_transform.input_of_name(input) - if deserializer is None or type(deserializer) is Deserializer: deserializer = getattr(input_transform, "deserializer", deserializer) @@ -810,56 +803,9 @@ def output_transform(self) -> OutputTransform: return getattr(self.data_pipeline, "_output_transform", None) def on_predict_start(self) -> None: - if self.trainer and isinstance(self.trainer.datamodule, NewDataModule) and not self._wrapped_predict_step: + if self.trainer and not self._wrapped_predict_step: self.predict_step = self._wrap_predict_step(self.predict_step) - def on_train_dataloader(self) -> None: - # TODO: Remove this logic when moving to the new DataModule - if self.trainer and isinstance(self.trainer.datamodule, NewDataModule): - return - if self.data_pipeline is not None: - self.data_pipeline._detach_from_model(self, RunningStage.TRAINING) - self.data_pipeline._attach_to_model(self, RunningStage.TRAINING) - super().on_train_dataloader() - - def on_val_dataloader(self) -> None: - if self.trainer and isinstance(self.trainer.datamodule, NewDataModule): - return - if self.data_pipeline is not None: - self.data_pipeline._detach_from_model(self, RunningStage.VALIDATING) - self.data_pipeline._attach_to_model(self, RunningStage.VALIDATING) - super().on_val_dataloader() - - def on_test_dataloader(self, *_) -> None: - if self.trainer and isinstance(self.trainer.datamodule, NewDataModule): - return - if self.data_pipeline is not None: - self.data_pipeline._detach_from_model(self, RunningStage.TESTING) - 
self.data_pipeline._attach_to_model(self, RunningStage.TESTING) - super().on_test_dataloader() - - def on_predict_dataloader(self) -> None: - if self.trainer and isinstance(self.trainer.datamodule, NewDataModule): - return - if self.data_pipeline is not None: - self.data_pipeline._detach_from_model(self, RunningStage.PREDICTING) - self.data_pipeline._attach_to_model(self, RunningStage.PREDICTING) - super().on_predict_dataloader() - - def on_predict_end(self) -> None: - if self.trainer and isinstance(self.trainer.datamodule, NewDataModule): - return - if self.data_pipeline is not None: - self.data_pipeline._detach_from_model(self) - super().on_predict_end() - - def on_fit_end(self) -> None: - if self.trainer and isinstance(self.trainer.datamodule, NewDataModule): - return - if self.data_pipeline is not None: - self.data_pipeline._detach_from_model(self) - super().on_fit_end() - def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None: # This may be an issue since here we create the same problems with pickle as in # https://pytorch.org/docs/stable/notes/serialization.html diff --git a/flash/core/serve/flash_components.py b/flash/core/serve/flash_components.py index 0149025979..ae111432d2 100644 --- a/flash/core/serve/flash_components.py +++ b/flash/core/serve/flash_components.py @@ -55,12 +55,9 @@ def __init__(self, model): self.model = model self.model.eval() self.data_pipeline = model.build_data_pipeline() - self.worker_input_transform_processor = self.data_pipeline.worker_input_transform_processor( - RunningStage.PREDICTING, is_serving=True - ) - self.device_input_transform_processor = self.data_pipeline.device_input_transform_processor( - RunningStage.PREDICTING - ) + self.deserializer = self.data_pipeline._deserializer + self.dataloader_collate_fn = self.data_pipeline._deserializer._create_dataloader_collate_fn([]) + self.on_after_batch_transfer_fn = self.data_pipeline._deserializer._create_on_after_batch_transfer_fn([]) self.output_transform_processor = self.data_pipeline.output_transform_processor( RunningStage.PREDICTING, is_serving=True ) @@ -69,17 +66,17 @@ def __init__(self, model): self.device = self.model.device @expose( - inputs={"inputs": FlashInputs(data_pipeline.deserialize_processor())}, + inputs={"inputs": FlashInputs(data_pipeline._deserializer._call_load_sample)}, outputs={"outputs": FlashOutputs(data_pipeline.output_processor())}, ) def predict(self, inputs): with torch.no_grad(): - inputs = self.worker_input_transform_processor(inputs) + inputs = self.dataloader_collate_fn(inputs) if self.extra_arguments: inputs = self.model.transfer_batch_to_device(inputs, self.device, 0) else: inputs = self.model.transfer_batch_to_device(inputs, self.device) - inputs = self.device_input_transform_processor(inputs) + inputs = self.on_after_batch_transfer_fn(inputs) preds = self.model.predict_step(inputs, 0) preds = self.output_transform_processor(preds) return preds diff --git a/flash/core/trainer.py b/flash/core/trainer.py index 34cce32db1..f969fc8e34 100644 --- a/flash/core/trainer.py +++ b/flash/core/trainer.py @@ -15,7 +15,7 @@ import warnings from argparse import ArgumentParser, Namespace from functools import wraps -from typing import Callable, Dict, List, Optional, Tuple, Union +from typing import Callable, List, Optional, Tuple, Union import torch from pytorch_lightning import LightningDataModule, LightningModule @@ -24,12 +24,11 @@ from pytorch_lightning.loops.fit_loop import FitLoop from pytorch_lightning.utilities.argparse import add_argparse_args, 
get_init_arguments_and_types, parse_env_variables from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.model_helpers import is_overridden from torch.utils.data import DataLoader import flash from flash.core.model import Task -from flash.core.utilities.imports import _PL_GREATER_EQUAL_1_5_0, _SERVE_AVAILABLE +from flash.core.utilities.imports import _SERVE_AVAILABLE def from_argparse_args(cls, args: Union[Namespace, ArgumentParser], **kwargs): @@ -237,58 +236,3 @@ def from_argparse_args(cls, args: Union[Namespace, ArgumentParser], **kwargs) -> # the lightning trainer implementation does not support subclasses. # context: https://github.com/PyTorchLightning/lightning-flash/issues/342#issuecomment-848892447 return from_argparse_args(Trainer, args, **kwargs) - - def _parse_request_dataloader_args(self, args: Tuple, kwargs: Dict): - """Handles backwards compatibility for ``request_dataloader``. - - Possible combinations: - - legacy: (model, stage) - (stage, model) - (stage, model=model) - """ - model, stage, is_legacy = None, None, False - if len(args) == 2: - # Check for legacy arguments: (model, stage) - if isinstance(args[0], LightningModule): - is_legacy = True - model, stage = args - else: # (stage, model) - stage, model = args - else: - stage = kwargs.get("stage", args[0]) - model = kwargs.get("model") - return model, stage, is_legacy - - def request_dataloader( - self, - *args, - **kwargs, - ) -> Union[DataLoader, List[DataLoader]]: - """Handles downloading data in the GPU or TPU case. - - Returns: - The dataloader - """ - model, stage, is_legacy = self._parse_request_dataloader_args(args, kwargs) - - if is_legacy: - self.call_hook(f"on_{stage}_dataloader") - dataloader = getattr(model, f"{stage}_dataloader")() - else: - hook = f"{stage.dataloader_prefix}_dataloader" - self.call_hook("on_" + hook, pl_module=model) - - if is_overridden(hook, model): - dataloader = self.call_hook(hook, pl_module=model) - elif _PL_GREATER_EQUAL_1_5_0: - source = getattr(self._data_connector, f"_{stage.dataloader_prefix}_dataloader_source") - dataloader = source.dataloader() - - if isinstance(dataloader, tuple): - dataloader = list(dataloader) - if _PL_GREATER_EQUAL_1_5_0: - self.training_type_plugin.barrier("get_dataloaders") - else: - self.accelerator.barrier("get_dataloaders") - return dataloader diff --git a/flash/graph/classification/data.py b/flash/graph/classification/data.py index ea1883b75b..46a2b25842 100644 --- a/flash/graph/classification/data.py +++ b/flash/graph/classification/data.py @@ -15,9 +15,9 @@ from torch.utils.data import Dataset +from flash.core.data.data_module import DataModule from flash.core.data.data_pipeline import DataPipelineState from flash.core.data.io.input import Input -from flash.core.data.new_data_module import DataModule from flash.core.utilities.stages import RunningStage from flash.core.utilities.types import INPUT_TRANSFORM_TYPE from flash.graph.classification.input import GraphClassificationDatasetInput diff --git a/flash/graph/classification/input.py b/flash/graph/classification/input.py index 13733b988d..680418d210 100644 --- a/flash/graph/classification/input.py +++ b/flash/graph/classification/input.py @@ -15,9 +15,9 @@ from torch.utils.data import Dataset +from flash.core.data.data_module import DatasetInput from flash.core.data.io.classification_input import ClassificationInput, ClassificationState from flash.core.data.io.input import DataKeys -from flash.core.data.new_data_module import DatasetInput 
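The import changes in this and the remaining hunks are mechanical: ``flash.core.data.new_data_module`` is deleted, its public symbols now live in ``flash.core.data.data_module``, and ``InputTransform`` is imported from ``flash.core.data.input_transform``. For downstream code the migration is:

.. code-block:: python

    # Before this change:
    # from flash.core.data.new_data_module import DataModule, DatasetInput
    # After:
    from flash.core.data.data_module import DataModule, DatasetInput
    from flash.core.data.input_transform import InputTransform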
from flash.core.utilities.imports import _GRAPH_AVAILABLE, requires if _GRAPH_AVAILABLE: diff --git a/flash/image/classification/data.py b/flash/image/classification/data.py index 60f954b86f..d7dd7b59e8 100644 --- a/flash/image/classification/data.py +++ b/flash/image/classification/data.py @@ -20,10 +20,10 @@ from flash.core.data.base_viz import BaseVisualization from flash.core.data.callback import BaseDataFetcher +from flash.core.data.data_module import DataModule, DatasetInput from flash.core.data.data_pipeline import DataPipelineState from flash.core.data.input_transform import INPUT_TRANSFORM_TYPE from flash.core.data.io.input import DataKeys, Input -from flash.core.data.new_data_module import DataModule, DatasetInput from flash.core.data.utilities.paths import PATH_TYPE from flash.core.integrations.labelstudio.input import _parse_labelstudio_arguments, LabelStudioImageClassificationInput from flash.core.registry import FlashRegistry diff --git a/flash/image/classification/integrations/baal/data.py b/flash/image/classification/integrations/baal/data.py index 56553350ad..2145cf62c9 100644 --- a/flash/image/classification/integrations/baal/data.py +++ b/flash/image/classification/integrations/baal/data.py @@ -19,9 +19,9 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException from torch.utils.data import DataLoader, Dataset, random_split +from flash.core.data.data_module import DataModule from flash.core.data.data_pipeline import DataPipeline from flash.core.data.io.input import InputBase -from flash.core.data.new_data_module import DataModule from flash.core.utilities.imports import _BAAL_AVAILABLE, requires if _BAAL_AVAILABLE: diff --git a/flash/image/detection/data.py b/flash/image/detection/data.py index b1335f9b29..9425dd01a8 100644 --- a/flash/image/detection/data.py +++ b/flash/image/detection/data.py @@ -13,9 +13,9 @@ # limitations under the License. from typing import Any, Dict, List, Optional, Type, TYPE_CHECKING +from flash.core.data.data_module import DataModule from flash.core.data.data_pipeline import DataPipelineState from flash.core.data.io.input import Input -from flash.core.data.new_data_module import DataModule from flash.core.integrations.icevision.data import IceVisionInput from flash.core.integrations.icevision.transforms import IceVisionInputTransform from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, _ICEVISION_AVAILABLE, lazy_import, requires diff --git a/flash/image/face_detection/data.py b/flash/image/face_detection/data.py index a46f46f36b..dcf9b4702d 100644 --- a/flash/image/face_detection/data.py +++ b/flash/image/face_detection/data.py @@ -15,9 +15,9 @@ from torch.utils.data import Dataset +from flash.core.data.data_module import DataModule from flash.core.data.data_pipeline import DataPipelineState from flash.core.data.io.input import Input -from flash.core.data.new_data_module import DataModule from flash.core.utilities.stages import RunningStage from flash.core.utilities.types import INPUT_TRANSFORM_TYPE from flash.image.classification.data import ImageClassificationFilesInput, ImageClassificationFolderInput diff --git a/flash/image/instance_segmentation/data.py b/flash/image/instance_segmentation/data.py index 6cd624e3ae..c08b2e7f2d 100644 --- a/flash/image/instance_segmentation/data.py +++ b/flash/image/instance_segmentation/data.py @@ -13,10 +13,10 @@ # limitations under the License. 
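The import moves in these hunks are mechanical: DataModule and DatasetInput now live in flash.core.data.data_module rather than flash.core.data.new_data_module, so downstream modules only change their import paths. A minimal usage sketch under that assumption, matching the constructor exercised by the updated tests further below (ToyInput is a hypothetical Input written for illustration):

from flash import DataKeys
from flash.core.data.data_module import DataModule
from flash.core.data.io.input import Input
from flash.core.utilities.stages import RunningStage

class ToyInput(Input):
    def load_data(self, samples):
        # A real Input would resolve files or folders here; this one just
        # wraps the given values as samples.
        return [{DataKeys.INPUT: x, DataKeys.TARGET: 0} for x in samples]

datamodule = DataModule(
    ToyInput(RunningStage.TRAINING, [1, 2, 3, 4]),
    batch_size=2,
)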
from typing import Any, Dict, List, Optional, Type +from flash.core.data.data_module import DataModule from flash.core.data.data_pipeline import DataPipelineState from flash.core.data.io.input import DataKeys, Input from flash.core.data.io.output_transform import OutputTransform -from flash.core.data.new_data_module import DataModule from flash.core.integrations.icevision.data import IceVisionInput from flash.core.integrations.icevision.transforms import IceVisionInputTransform as InstanceSegmentationInputTransform from flash.core.utilities.imports import _ICEVISION_AVAILABLE diff --git a/flash/image/keypoint_detection/data.py b/flash/image/keypoint_detection/data.py index 799fef7520..c3b8ef761b 100644 --- a/flash/image/keypoint_detection/data.py +++ b/flash/image/keypoint_detection/data.py @@ -13,9 +13,9 @@ # limitations under the License. from typing import Any, Dict, List, Optional, Type +from flash.core.data.data_module import DataModule from flash.core.data.data_pipeline import DataPipelineState from flash.core.data.io.input import Input -from flash.core.data.new_data_module import DataModule from flash.core.integrations.icevision.data import IceVisionInput from flash.core.integrations.icevision.transforms import IceVisionInputTransform from flash.core.utilities.imports import _ICEVISION_AVAILABLE diff --git a/flash/image/segmentation/data.py b/flash/image/segmentation/data.py index eb59a25c64..60a8477577 100644 --- a/flash/image/segmentation/data.py +++ b/flash/image/segmentation/data.py @@ -17,9 +17,9 @@ import torch from flash.core.data.callback import BaseDataFetcher +from flash.core.data.data_module import DataModule from flash.core.data.data_pipeline import DataPipelineState from flash.core.data.io.input import Input -from flash.core.data.new_data_module import DataModule from flash.core.registry import FlashRegistry from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, lazy_import from flash.core.utilities.stages import RunningStage diff --git a/flash/image/style_transfer/data.py b/flash/image/style_transfer/data.py index 159118682e..a42e37e096 100644 --- a/flash/image/style_transfer/data.py +++ b/flash/image/style_transfer/data.py @@ -16,9 +16,9 @@ import numpy as np import torch +from flash.core.data.data_module import DataModule from flash.core.data.data_pipeline import DataPipelineState from flash.core.data.io.input import Input -from flash.core.data.new_data_module import DataModule from flash.core.utilities.stages import RunningStage from flash.core.utilities.types import INPUT_TRANSFORM_TYPE from flash.image.classification.input import ImageClassificationFilesInput, ImageClassificationFolderInput diff --git a/flash/pointcloud/detection/data.py b/flash/pointcloud/detection/data.py index a056debbe6..116ec52218 100644 --- a/flash/pointcloud/detection/data.py +++ b/flash/pointcloud/detection/data.py @@ -15,10 +15,10 @@ from torch.utils.data import Dataset +from flash.core.data.data_module import DataModule from flash.core.data.data_pipeline import DataPipelineState from flash.core.data.input_transform import InputTransform from flash.core.data.io.input import BaseDataFormat, Input -from flash.core.data.new_data_module import DataModule from flash.core.utilities.stages import RunningStage from flash.core.utilities.types import INPUT_TRANSFORM_TYPE from flash.pointcloud.detection.input import PointCloudObjectDetectorDatasetInput diff --git a/flash/pointcloud/segmentation/data.py b/flash/pointcloud/segmentation/data.py index 3796f293ac..f87345e241 100644 --- 
a/flash/pointcloud/segmentation/data.py +++ b/flash/pointcloud/segmentation/data.py @@ -15,10 +15,10 @@ from torch.utils.data import Dataset +from flash.core.data.data_module import DataModule from flash.core.data.data_pipeline import DataPipelineState from flash.core.data.input_transform import InputTransform from flash.core.data.io.input import Input -from flash.core.data.new_data_module import DataModule from flash.core.utilities.stages import RunningStage from flash.core.utilities.types import INPUT_TRANSFORM_TYPE from flash.pointcloud.segmentation.input import PointCloudSegmentationDatasetInput, PointCloudSegmentationFoldersInput diff --git a/flash/tabular/data.py b/flash/tabular/data.py index 1fde1fc62b..23dcd1e4de 100644 --- a/flash/tabular/data.py +++ b/flash/tabular/data.py @@ -13,11 +13,11 @@ # limitations under the License. from typing import Any, Dict, List, Optional, Type, Union +from flash.core.data.data_module import DataModule from flash.core.data.data_pipeline import DataPipelineState from flash.core.data.input_transform import INPUT_TRANSFORM_TYPE, InputTransform from flash.core.data.io.input import Input from flash.core.data.io.output_transform import OutputTransform -from flash.core.data.new_data_module import DataModule from flash.core.utilities.imports import _PANDAS_AVAILABLE from flash.core.utilities.stages import RunningStage from flash.tabular.input import TabularCSVInput, TabularDataFrameInput diff --git a/flash/tabular/forecasting/data.py b/flash/tabular/forecasting/data.py index c36dc98efb..e4cfbd078c 100644 --- a/flash/tabular/forecasting/data.py +++ b/flash/tabular/forecasting/data.py @@ -16,11 +16,11 @@ from torch.utils.data.sampler import Sampler from flash.core.data.callback import BaseDataFetcher +from flash.core.data.data_module import DataModule from flash.core.data.data_pipeline import DataPipelineState from flash.core.data.input_transform import INPUT_TRANSFORM_TYPE, InputTransform from flash.core.data.io.input import Input from flash.core.data.io.output_transform import OutputTransform -from flash.core.data.new_data_module import DataModule from flash.core.utilities.imports import _PANDAS_AVAILABLE from flash.core.utilities.stages import RunningStage from flash.tabular.forecasting.input import TabularForecastingDataFrameInput diff --git a/flash/template/classification/data.py b/flash/template/classification/data.py index 0374a876d9..4bb1e93755 100644 --- a/flash/template/classification/data.py +++ b/flash/template/classification/data.py @@ -18,11 +18,11 @@ from flash.core.data.base_viz import BaseVisualization from flash.core.data.callback import BaseDataFetcher +from flash.core.data.data_module import DataModule from flash.core.data.data_pipeline import DataPipelineState from flash.core.data.input_transform import InputTransform from flash.core.data.io.classification_input import ClassificationInput from flash.core.data.io.input import DataKeys, Input -from flash.core.data.new_data_module import DataModule from flash.core.data.utilities.samples import to_samples from flash.core.utilities.imports import _SKLEARN_AVAILABLE from flash.core.utilities.stages import RunningStage diff --git a/flash/text/classification/data.py b/flash/text/classification/data.py index 1e2fcbfd16..cadad286d9 100644 --- a/flash/text/classification/data.py +++ b/flash/text/classification/data.py @@ -15,9 +15,9 @@ from pandas.core.frame import DataFrame +from flash.core.data.data_module import DataModule from flash.core.data.data_pipeline import DataPipelineState from 
flash.core.data.io.input import Input -from flash.core.data.new_data_module import DataModule from flash.core.data.utilities.paths import PATH_TYPE from flash.core.integrations.labelstudio.input import _parse_labelstudio_arguments, LabelStudioTextClassificationInput from flash.core.integrations.transformers.input_transform import TransformersInputTransform diff --git a/flash/text/input.py b/flash/text/input.py index 977c15fd1f..96a3171627 100644 --- a/flash/text/input.py +++ b/flash/text/input.py @@ -20,8 +20,8 @@ class TextDeserializer(Deserializer): @requires("text") - def __init__(self, max_length: int = 128): - super().__init__() + def __init__(self, *args, max_length: int = 128, **kwargs): + super().__init__(*args, **kwargs) self.max_length = max_length def serve_load_sample(self, text: str) -> Tensor: diff --git a/flash/text/question_answering/data.py b/flash/text/question_answering/data.py index d0063cdbd9..5a71b73bc1 100644 --- a/flash/text/question_answering/data.py +++ b/flash/text/question_answering/data.py @@ -13,9 +13,9 @@ # limitations under the License. from typing import Any, Dict, Optional, Type, Union +from flash.core.data.data_module import DataModule from flash.core.data.data_pipeline import DataPipelineState from flash.core.data.io.input import Input -from flash.core.data.new_data_module import DataModule from flash.core.data.utilities.paths import PATH_TYPE from flash.core.utilities.stages import RunningStage from flash.core.utilities.types import INPUT_TRANSFORM_TYPE diff --git a/flash/text/seq2seq/core/data.py b/flash/text/seq2seq/core/data.py index 5f6e552239..4e7f4ca09a 100644 --- a/flash/text/seq2seq/core/data.py +++ b/flash/text/seq2seq/core/data.py @@ -13,9 +13,9 @@ # limitations under the License. from typing import Any, Dict, List, Optional, Type, Union +from flash.core.data.data_module import DataModule from flash.core.data.data_pipeline import DataPipelineState from flash.core.data.io.input import Input -from flash.core.data.new_data_module import DataModule from flash.core.data.utilities.paths import PATH_TYPE from flash.core.integrations.transformers.input_transform import TransformersInputTransform from flash.core.utilities.imports import _TEXT_AVAILABLE diff --git a/flash/video/classification/data.py b/flash/video/classification/data.py index 3a3d7ec41f..6af825c80a 100644 --- a/flash/video/classification/data.py +++ b/flash/video/classification/data.py @@ -16,10 +16,10 @@ import torch from torch.utils.data import Sampler +from flash.core.data.data_module import DataModule from flash.core.data.data_pipeline import DataPipelineState from flash.core.data.input_transform import INPUT_TRANSFORM_TYPE from flash.core.data.io.input import Input -from flash.core.data.new_data_module import DataModule from flash.core.integrations.labelstudio.input import _parse_labelstudio_arguments, LabelStudioVideoClassificationInput from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, _PYTORCHVIDEO_AVAILABLE, lazy_import, requires from flash.core.utilities.stages import RunningStage diff --git a/flash_examples/flash_components/custom_data_loading.py b/flash_examples/flash_components/custom_data_loading.py index c2df25462b..a3a34bc0b2 100644 --- a/flash_examples/flash_components/custom_data_loading.py +++ b/flash_examples/flash_components/custom_data_loading.py @@ -23,10 +23,10 @@ from pytorch_lightning import seed_everything from flash import _PACKAGE_ROOT, RunningStage +from flash.core.data.data_module import DataModule from flash.core.data.data_pipeline import 
DataPipelineState from flash.core.data.input_transform import INPUT_TRANSFORM_TYPE, InputTransform from flash.core.data.io.input import DataKeys, Input -from flash.core.data.new_data_module import DataModule from flash.core.data.utils import download_data seed_everything(42) diff --git a/flash_examples/serve/image_classification/inference_server.py b/flash_examples/serve/image_classification/inference_server.py index 00a519a63b..9b173e1c14 100644 --- a/flash_examples/serve/image_classification/inference_server.py +++ b/flash_examples/serve/image_classification/inference_server.py @@ -11,9 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from flash.image import ImageClassifier +from flash import RunningStage +from flash.image import ImageClassificationInputTransform, ImageClassifier +from flash.image.data import ImageDeserializer model = ImageClassifier.load_from_checkpoint( "https://flash-weights.s3.amazonaws.com/0.6.0/image_classification_model.pt" ) +model.deserializer = ImageDeserializer(transform=ImageClassificationInputTransform(RunningStage.SERVING)) model.serve() diff --git a/tests/audio/speech_recognition/test_model.py b/tests/audio/speech_recognition/test_model.py index 7869639eb4..f817a8c561 100644 --- a/tests/audio/speech_recognition/test_model.py +++ b/tests/audio/speech_recognition/test_model.py @@ -22,9 +22,9 @@ from flash import Trainer from flash.__main__ import main from flash.audio import SpeechRecognition -from flash.audio.speech_recognition.data import InputTransform, SpeechRecognitionOutputTransform +from flash.audio.speech_recognition.data import InputTransform, SpeechRecognitionData, SpeechRecognitionOutputTransform from flash.audio.speech_recognition.input import SpeechRecognitionDeserializer -from flash.core.data.io.input import DataKeys +from flash.core.data.io.input import DataKeys, Input from flash.core.utilities.imports import _AUDIO_AVAILABLE from flash.core.utilities.stages import RunningStage from tests.helpers.utils import _AUDIO_TESTING, _SERVE_TESTING @@ -53,9 +53,11 @@ def __len__(self) -> int: @pytest.mark.skipif(not _AUDIO_TESTING, reason="audio libraries aren't installed.") def test_init_train(tmpdir): model = SpeechRecognition(backbone=TEST_BACKBONE) - train_dl = torch.utils.data.DataLoader(DummyDataset()) + datamodule = SpeechRecognitionData( + Input(RunningStage.TRAINING, DummyDataset(), transform=InputTransform), batch_size=2 + ) trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) - trainer.fit(model, train_dl) + trainer.fit(model, datamodule=datamodule) @pytest.mark.skipif(not _AUDIO_TESTING, reason="audio libraries aren't installed.") @@ -82,10 +84,10 @@ def test_jit(tmpdir): def test_serve(): model = SpeechRecognition(backbone=TEST_BACKBONE) - # TODO: Currently only servable once a input_transform and postprocess have been attached - model._input_transform = InputTransform(RunningStage.SERVING) + model._deserializer = SpeechRecognitionDeserializer(transform=InputTransform(RunningStage.SERVING)) + # TODO: Serve should share the state + model._deserializer.transform._state = model._state model._output_transform = SpeechRecognitionOutputTransform() - model._deserializer = SpeechRecognitionDeserializer() model.eval() model.serve() diff --git a/tests/core/data/io/test_input_transform.py b/tests/core/data/io/test_input_transform.py deleted file mode 100644 index 46b7c57b50..0000000000 --- 
a/tests/core/data/io/test_input_transform.py +++ /dev/null @@ -1,73 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from unittest.mock import Mock - -import pytest -import torch -from pytorch_lightning.utilities.exceptions import MisconfigurationException -from torch.utils.data._utils.collate import default_collate - -from flash.core.data.io.input import InputFormat -from flash.core.data.io.input_transform import _InputTransformProcessor, DefaultInputTransform -from flash.core.utilities.stages import RunningStage - - -class CustomInputTransform(DefaultInputTransform): - def __init__(self): - super().__init__( - inputs={ - "test": Mock(return_value="test"), - InputFormat.TENSORS: Mock(return_value="tensors"), - }, - default_input="test", - ) - - -def test_input_transform_processor_str(): - input_transform_processor = _InputTransformProcessor( - Mock(name="input_transform"), - default_collate, - torch.relu, - torch.softmax, - RunningStage.TRAINING, - False, - True, - ) - assert str(input_transform_processor) == ( - "_InputTransformProcessor:\n" - "\t(per_sample_transform): FuncModule(relu)\n" - "\t(collate_fn): FuncModule(default_collate)\n" - "\t(per_batch_transform): FuncModule(softmax)\n" - "\t(apply_per_sample_transform): False\n" - "\t(on_device): True\n" - "\t(stage): RunningStage.TRAINING" - ) - - -def test_input_of_name(): - input_transform = CustomInputTransform() - - assert input_transform.input_of_name("test")() == "test" - assert input_transform.input_of_name(InputFormat.TENSORS)() == "tensors" - assert input_transform.input_of_name("tensors")() == "tensors" - assert input_transform.input_of_name("default")() == "test" - - with pytest.raises(MisconfigurationException, match="available data sources are: test, tensor"): - input_transform.input_of_name("not available") - - -def test_check_transforms(): - transform = torch.nn.Identity() - DefaultInputTransform(train_transform=transform) - DefaultInputTransform(train_transform=[transform]) diff --git a/tests/core/data/io/test_output.py b/tests/core/data/io/test_output.py index 258c0a75cc..543e2aac33 100644 --- a/tests/core/data/io/test_output.py +++ b/tests/core/data/io/test_output.py @@ -17,10 +17,11 @@ import torch from torch.utils.data import DataLoader +from flash import RunningStage from flash.core.classification import LabelsOutput from flash.core.data.data_pipeline import DataPipeline, DataPipelineState +from flash.core.data.input_transform import InputTransform from flash.core.data.io.classification_input import ClassificationState -from flash.core.data.io.input_transform import DefaultInputTransform from flash.core.data.io.output import Output from flash.core.model import Task from flash.core.trainer import Trainer @@ -47,10 +48,10 @@ def __init__(self): output = LabelsOutput(["a", "b"]) model = CustomModel() trainer = Trainer(fast_dev_run=True) - data_pipeline = DataPipeline(input_transform=DefaultInputTransform(), output=output) + data_pipeline = 
DataPipeline(input_transform=InputTransform(RunningStage.TRAINING), output=output) data_pipeline.initialize() model.data_pipeline = data_pipeline - assert isinstance(model.input_transform, DefaultInputTransform) + assert isinstance(model.input_transform, InputTransform) dummy_data = DataLoader(list(zip(torch.arange(10, dtype=torch.float), torch.arange(10, dtype=torch.float)))) trainer.fit(model, train_dataloader=dummy_data) trainer.save_checkpoint(checkpoint_file) diff --git a/tests/core/data/test_callback.py b/tests/core/data/test_callback.py index a7eebea2be..d74fc8d897 100644 --- a/tests/core/data/test_callback.py +++ b/tests/core/data/test_callback.py @@ -17,8 +17,8 @@ import torch from flash import DataKeys +from flash.core.data.data_module import DataModule, DatasetInput from flash.core.data.input_transform import InputTransform -from flash.core.data.new_data_module import DataModule, DatasetInput from flash.core.model import Task from flash.core.trainer import Trainer from flash.core.utilities.stages import RunningStage diff --git a/tests/core/data/test_callbacks.py b/tests/core/data/test_callbacks.py index cd7c69cc17..346c0fdbfd 100644 --- a/tests/core/data/test_callbacks.py +++ b/tests/core/data/test_callbacks.py @@ -17,9 +17,9 @@ from torch import tensor from flash.core.data.callback import BaseDataFetcher +from flash.core.data.data_module import DataModule from flash.core.data.input_transform import InputTransform from flash.core.data.io.input import Input -from flash.core.data.new_data_module import DataModule from flash.core.utilities.stages import RunningStage diff --git a/tests/core/data/test_new_data_module.py b/tests/core/data/test_data_module.py similarity index 94% rename from tests/core/data/test_new_data_module.py rename to tests/core/data/test_data_module.py index ccdfedf462..45e6185c7e 100644 --- a/tests/core/data/test_new_data_module.py +++ b/tests/core/data/test_data_module.py @@ -13,6 +13,7 @@ # limitations under the License. 
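With DefaultInputTransform removed, the updated test above wires the pipeline from a stage-aware InputTransform instance instead of a subclass. The pattern as a standalone sketch (LabelsOutput and the initialize() call are taken directly from the test; the comment on state sharing is an assumption based on the DataPipelineState imports used throughout this patch):

from flash import RunningStage
from flash.core.classification import LabelsOutput
from flash.core.data.data_pipeline import DataPipeline
from flash.core.data.input_transform import InputTransform

# A bare InputTransform now carries its RunningStage, so no subclass is needed
# just to satisfy the pipeline.
data_pipeline = DataPipeline(
    input_transform=InputTransform(RunningStage.TRAINING),
    output=LabelsOutput(["a", "b"]),
)
data_pipeline.initialize()  # presumably wires up the shared DataPipelineState

# Assigning `model.data_pipeline = data_pipeline` on a Task then exposes the
# transform as `model.input_transform`, which is what the test asserts.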
 from dataclasses import dataclass
 from typing import Callable, Dict
+from unittest import mock
 
 import numpy as np
 import pytest
@@ -21,9 +22,9 @@
 from torch.utils.data import Dataset
 
 from flash import Task, Trainer
+from flash.core.data.data_module import DataModule, DatasetInput
 from flash.core.data.input_transform import InputTransform
 from flash.core.data.io.input import Input
-from flash.core.data.new_data_module import DataModule, DatasetInput
 from flash.core.data.states import PerBatchTransformOnDevice, PerSampleTransform
 from flash.core.utilities.imports import _TORCHVISION_AVAILABLE
 from flash.core.utilities.stages import RunningStage
@@ -417,3 +418,25 @@ def validation_step(self, batch, batch_idx):
         num_sanity_val_steps=1,
     )
     trainer.fit(model, datamodule=datamodule)
+
+
+@mock.patch("flash.core.data.data_module.DataLoader")
+def test_dataloaders_with_sampler(mock_dataloader):
+    mock_sampler = mock.MagicMock()
+    datamodule = DataModule(
+        TestInput(RunningStage.TRAINING, [1]),
+        TestInput(RunningStage.VALIDATING, [1]),
+        TestInput(RunningStage.TESTING, [1]),
+        batch_size=2,
+        num_workers=0,
+        sampler=mock_sampler,
+    )
+    assert datamodule.sampler is mock_sampler
+    dl = datamodule.train_dataloader()
+    kwargs = mock_dataloader.call_args[1]
+    assert "sampler" in kwargs
+    assert kwargs["sampler"] is mock_sampler.return_value
+    for dataloader_fn in (datamodule.val_dataloader, datamodule.test_dataloader):
+        dataloader_fn()
+        kwargs = mock_dataloader.call_args[1]
+        assert "sampler" not in kwargs
diff --git a/tests/core/data/test_data_pipeline.py b/tests/core/data/test_data_pipeline.py
index dd58bbb879..f820438c43 100644
--- a/tests/core/data/test_data_pipeline.py
+++ b/tests/core/data/test_data_pipeline.py
@@ -11,25 +11,20 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
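The relocated sampler test above pins down the contract: the sampler is forwarded to the training DataLoader only, and, since the assertion checks mock_sampler.return_value, the DataModule calls the sampler itself rather than passing it through verbatim. A hedged usage sketch, reusing the ToyInput defined earlier and assuming the DataModule instantiates the sampler with the train dataset:

from torch.utils.data import RandomSampler

datamodule = DataModule(
    ToyInput(RunningStage.TRAINING, [1, 2, 3, 4]),
    ToyInput(RunningStage.VALIDATING, [1, 2, 3, 4]),
    batch_size=2,
    sampler=RandomSampler,  # a sampler class/callable, not an instance
)
train_loader = datamodule.train_dataloader()  # built with the instantiated sampler
val_loader = datamodule.val_dataloader()      # built without a sampler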
-import os -from typing import Any, cast, Dict, Optional, Tuple +from typing import cast, Tuple import pytest import torch from pytorch_lightning.utilities.exceptions import MisconfigurationException -from torch import Tensor, tensor -from torch.utils.data import DataLoader -from torch.utils.data._utils.collate import default_collate +from torch import Tensor -from flash import Trainer -from flash.core.data.data_pipeline import _StageOrchestrator, DataPipeline, DataPipelineState +from flash.core.data.data_pipeline import DataPipeline, DataPipelineState +from flash.core.data.input_transform import InputTransform from flash.core.data.io.input import Input -from flash.core.data.io.input_transform import _InputTransformProcessor, DefaultInputTransform, InputTransform from flash.core.data.io.output import Output -from flash.core.data.io.output_transform import _OutputTransformProcessor, OutputTransform +from flash.core.data.io.output_transform import OutputTransform from flash.core.data.process import Deserializer from flash.core.data.properties import ProcessState -from flash.core.model import Task from flash.core.utilities.stages import RunningStage @@ -71,442 +66,19 @@ def test_data_pipeline_str(): assert str(data_pipeline) == (f"DataPipeline({expected})") -@pytest.mark.parametrize("use_input_transform", [False, True]) -@pytest.mark.parametrize("use_output_transform", [False, True]) -def test_data_pipeline_init_and_assignement(use_input_transform, use_output_transform, tmpdir): - class CustomModel(Task): - def __init__(self, output_transform: Optional[OutputTransform] = None): - super().__init__(model=torch.nn.Linear(1, 1), loss_fn=torch.nn.MSELoss()) - self._output_transform = output_transform - - def train_dataloader(self) -> Any: - return DataLoader(DummyDataset()) - - class SubInputTransform(DefaultInputTransform): - pass - - class SubOutputTransform(OutputTransform): - pass - - data_pipeline = DataPipeline( - input_transform=SubInputTransform() if use_input_transform else None, - output_transform=SubOutputTransform() if use_output_transform else None, - ) - assert isinstance( - data_pipeline._input_transform_pipeline, SubInputTransform if use_input_transform else DefaultInputTransform - ) - assert isinstance(data_pipeline._output_transform, SubOutputTransform if use_output_transform else OutputTransform) - - model = CustomModel(output_transform=OutputTransform()) - model.data_pipeline = data_pipeline - # TODO: the line below should make the same effect but it's not - # data_pipeline._attach_to_model(model) - - if use_input_transform: - assert isinstance(model._input_transform, SubInputTransform) - else: - assert model._input_transform is None or isinstance(model._input_transform, InputTransform) - - if use_output_transform: - assert isinstance(model._output_transform, SubOutputTransform) - else: - assert model._output_transform is None or isinstance(model._output_transform, OutputTransform) - - -def test_data_pipeline_is_overridden_and_resolve_function_hierarchy(tmpdir): - class CustomInputTransform(DefaultInputTransform): - def val_per_sample_transform(self, *_, **__): - pass - - def test_collate(self, *_, **__): - pass - - def val_per_sample_transform_on_device(self, *_, **__): - pass - - def train_per_batch_transform_on_device(self, *_, **__): - pass - - def test_per_batch_transform_on_device(self, *_, **__): - pass - - input_transform = CustomInputTransform() - data_pipeline = DataPipeline(input_transform=input_transform) - - train_func_names: Dict[str, str] = { - k: 
data_pipeline._resolve_function_hierarchy( - k, data_pipeline._input_transform_pipeline, RunningStage.TRAINING, InputTransform - ) - for k in data_pipeline.INPUT_TRANSFORM_FUNCS - } - val_func_names: Dict[str, str] = { - k: data_pipeline._resolve_function_hierarchy( - k, data_pipeline._input_transform_pipeline, RunningStage.VALIDATING, InputTransform - ) - for k in data_pipeline.INPUT_TRANSFORM_FUNCS - } - test_func_names: Dict[str, str] = { - k: data_pipeline._resolve_function_hierarchy( - k, data_pipeline._input_transform_pipeline, RunningStage.TESTING, InputTransform - ) - for k in data_pipeline.INPUT_TRANSFORM_FUNCS - } - predict_func_names: Dict[str, str] = { - k: data_pipeline._resolve_function_hierarchy( - k, data_pipeline._input_transform_pipeline, RunningStage.PREDICTING, InputTransform - ) - for k in data_pipeline.INPUT_TRANSFORM_FUNCS - } - - # per_sample_transform - assert train_func_names["per_sample_transform"] == "per_sample_transform" - assert val_func_names["per_sample_transform"] == "val_per_sample_transform" - assert test_func_names["per_sample_transform"] == "per_sample_transform" - assert predict_func_names["per_sample_transform"] == "per_sample_transform" - - # collate - assert train_func_names["collate"] == "collate" - assert val_func_names["collate"] == "collate" - assert test_func_names["collate"] == "test_collate" - assert predict_func_names["collate"] == "collate" - - # per_sample_transform_on_device - assert train_func_names["per_sample_transform_on_device"] == "per_sample_transform_on_device" - assert val_func_names["per_sample_transform_on_device"] == "val_per_sample_transform_on_device" - assert test_func_names["per_sample_transform_on_device"] == "per_sample_transform_on_device" - assert predict_func_names["per_sample_transform_on_device"] == "per_sample_transform_on_device" - - # per_batch_transform_on_device - assert train_func_names["per_batch_transform_on_device"] == "train_per_batch_transform_on_device" - assert val_func_names["per_batch_transform_on_device"] == "per_batch_transform_on_device" - assert test_func_names["per_batch_transform_on_device"] == "test_per_batch_transform_on_device" - assert predict_func_names["per_batch_transform_on_device"] == "per_batch_transform_on_device" - - train_worker_input_transform_processor = data_pipeline.worker_input_transform_processor(RunningStage.TRAINING) - val_worker_input_transform_processor = data_pipeline.worker_input_transform_processor(RunningStage.VALIDATING) - test_worker_input_transform_processor = data_pipeline.worker_input_transform_processor(RunningStage.TESTING) - predict_worker_input_transform_processor = data_pipeline.worker_input_transform_processor(RunningStage.PREDICTING) - - assert train_worker_input_transform_processor.per_sample_transform.func == input_transform.per_sample_transform - assert train_worker_input_transform_processor.collate_fn.func == input_transform.collate - assert train_worker_input_transform_processor.per_batch_transform.func == input_transform.per_batch_transform - - assert val_worker_input_transform_processor.per_sample_transform.func == input_transform.val_per_sample_transform - assert val_worker_input_transform_processor.collate_fn.func == DataPipeline._identity - assert val_worker_input_transform_processor.per_batch_transform.func == input_transform.per_batch_transform - - assert test_worker_input_transform_processor.per_sample_transform.func == input_transform.per_sample_transform - assert test_worker_input_transform_processor.collate_fn.func == 
input_transform.test_collate - assert test_worker_input_transform_processor.per_batch_transform.func == input_transform.per_batch_transform - - assert predict_worker_input_transform_processor.per_sample_transform.func == input_transform.per_sample_transform - assert predict_worker_input_transform_processor.collate_fn.func == input_transform.collate - assert predict_worker_input_transform_processor.per_batch_transform.func == input_transform.per_batch_transform - - -class CustomInputTransform(DefaultInputTransform): - def train_per_sample_transform(self, *_, **__): - pass - - def train_per_batch_transform_on_device(self, *_, **__): - pass - - def test_per_sample_transform(self, *_, **__): - pass - - def test_per_batch_transform(self, *_, **__): - pass - - def test_per_sample_transform_on_device(self, *_, **__): - pass - - def test_per_batch_transform_on_device(self, *_, **__): - pass - - def val_per_batch_transform(self, *_, **__): - pass - - def val_per_sample_transform_on_device(self, *_, **__): - pass - - def predict_per_sample_transform(self, *_, **__): - pass - - def predict_per_sample_transform_on_device(self, *_, **__): - pass - - def predict_per_batch_transform_on_device(self, *_, **__): - pass - - -def test_data_pipeline_predict_worker_input_transform_processor_and_device_input_transform_processor(): - - input_transform = CustomInputTransform() - data_pipeline = DataPipeline(input_transform=input_transform) - - data_pipeline.worker_input_transform_processor(RunningStage.TRAINING) - with pytest.raises(MisconfigurationException, match="are mutually exclusive"): - data_pipeline.worker_input_transform_processor(RunningStage.VALIDATING) - with pytest.raises(MisconfigurationException, match="are mutually exclusive"): - data_pipeline.worker_input_transform_processor(RunningStage.TESTING) - data_pipeline.worker_input_transform_processor(RunningStage.PREDICTING) - - -def test_detach_input_transform_from_model(tmpdir): - class CustomModel(Task): - def __init__(self, output_transform: Optional[OutputTransform] = None): - super().__init__(model=torch.nn.Linear(1, 1), loss_fn=torch.nn.MSELoss()) - self._output_transform = output_transform - - def train_dataloader(self) -> Any: - return DataLoader(DummyDataset()) - - input_transform = CustomInputTransform() - data_pipeline = DataPipeline(input_transform=input_transform) - model = CustomModel() - model.data_pipeline = data_pipeline - - assert model.train_dataloader().collate_fn == default_collate - assert model.transfer_batch_to_device.__self__ == model - model.on_train_dataloader() - assert isinstance(model.train_dataloader().collate_fn, _InputTransformProcessor) - assert isinstance(model.transfer_batch_to_device, _StageOrchestrator) - model.on_fit_end() - assert model.transfer_batch_to_device.__self__ == model - assert model.train_dataloader().collate_fn == default_collate - - -class TestInputTransform(DefaultInputTransform): - def train_per_sample_transform(self, *_, **__): - pass - - def train_per_batch_transform_on_device(self, *_, **__): - pass - - def test_per_sample_transform(self, *_, **__): - pass - - def test_per_sample_transform_on_device(self, *_, **__): - pass - - def test_per_batch_transform_on_device(self, *_, **__): - pass - - def val_per_sample_transform_on_device(self, *_, **__): - pass - - def predict_per_sample_transform(self, *_, **__): - pass - - def predict_per_sample_transform_on_device(self, *_, **__): - pass - - def predict_per_batch_transform_on_device(self, *_, **__): - pass - - -def 
test_attaching_datapipeline_to_model(tmpdir): - class SubInputTransform(DefaultInputTransform): - pass - - input_transform = SubInputTransform() - data_pipeline = DataPipeline(input_transform=input_transform) - - class CustomModel(Task): - def __init__(self): - super().__init__(model=torch.nn.Linear(1, 1), loss_fn=torch.nn.MSELoss()) - self._output_transform = OutputTransform() - - def training_step(self, batch: Any, batch_idx: int) -> Any: - pass - - def validation_step(self, batch: Any, batch_idx: int) -> Any: - pass - - def test_step(self, batch: Any, batch_idx: int) -> Any: - pass - - def train_dataloader(self) -> Any: - return DataLoader(DummyDataset()) - - def val_dataloader(self) -> Any: - return DataLoader(DummyDataset()) - - def test_dataloader(self) -> Any: - return DataLoader(DummyDataset()) - - def predict_dataloader(self) -> Any: - return DataLoader(DummyDataset()) - - class TestModel(CustomModel): - - stages = [RunningStage.TRAINING, RunningStage.VALIDATING, RunningStage.TESTING, RunningStage.PREDICTING] - on_train_start_called = False - on_val_start_called = False - on_test_start_called = False - on_predict_start_called = False - - def on_fit_start(self): - self._saved_predict_step = self.predict_step - - @staticmethod - def _compare_pre_processor(p1, p2): - assert p1.per_sample_transform.func == p2.per_sample_transform.func - assert p1.collate_fn.func == p2.collate_fn.func - assert p1.per_batch_transform.func == p2.per_batch_transform.func - +def test_is_overridden_recursive(tmpdir): + class TestInputTransform(InputTransform): @staticmethod - def _assert_stage_orchestrator_state( - stage_mapping: Dict, current_running_stage: RunningStage, cls=_InputTransformProcessor - ): - assert isinstance(stage_mapping[current_running_stage], cls) - assert stage_mapping[current_running_stage] - - def on_train_dataloader(self) -> None: - current_running_stage = RunningStage.TRAINING - self.on_train_dataloader_called = True - collate_fn = self.train_dataloader().collate_fn # noqa F811 - assert collate_fn == default_collate - assert not isinstance(self.transfer_batch_to_device, _StageOrchestrator) - super().on_train_dataloader() - collate_fn = self.train_dataloader().collate_fn # noqa F811 - assert collate_fn.stage == current_running_stage - self._compare_pre_processor( - collate_fn, self.data_pipeline.worker_input_transform_processor(current_running_stage) - ) - assert isinstance(self.transfer_batch_to_device, _StageOrchestrator) - self._assert_stage_orchestrator_state(self.transfer_batch_to_device._stage_mapping, current_running_stage) - - def on_val_dataloader(self) -> None: - current_running_stage = RunningStage.VALIDATING - self.on_val_dataloader_called = True - collate_fn = self.val_dataloader().collate_fn # noqa F811 - assert collate_fn == default_collate - assert isinstance(self.transfer_batch_to_device, _StageOrchestrator) - super().on_val_dataloader() - collate_fn = self.val_dataloader().collate_fn # noqa F811 - assert collate_fn.stage == current_running_stage - self._compare_pre_processor( - collate_fn, self.data_pipeline.worker_input_transform_processor(current_running_stage) - ) - assert isinstance(self.transfer_batch_to_device, _StageOrchestrator) - self._assert_stage_orchestrator_state(self.transfer_batch_to_device._stage_mapping, current_running_stage) - - def on_test_dataloader(self) -> None: - current_running_stage = RunningStage.TESTING - self.on_test_dataloader_called = True - collate_fn = self.test_dataloader().collate_fn # noqa F811 - assert collate_fn == 
default_collate - assert not isinstance(self.transfer_batch_to_device, _StageOrchestrator) - super().on_test_dataloader() - collate_fn = self.test_dataloader().collate_fn # noqa F811 - assert collate_fn.stage == current_running_stage - self._compare_pre_processor( - collate_fn, self.data_pipeline.worker_input_transform_processor(current_running_stage) - ) - assert isinstance(self.transfer_batch_to_device, _StageOrchestrator) - self._assert_stage_orchestrator_state(self.transfer_batch_to_device._stage_mapping, current_running_stage) - - def on_predict_dataloader(self) -> None: - current_running_stage = RunningStage.PREDICTING - self.on_predict_dataloader_called = True - collate_fn = self.predict_dataloader().collate_fn # noqa F811 - assert collate_fn == default_collate - assert isinstance(self.transfer_batch_to_device, _StageOrchestrator) - assert self.predict_step == self._saved_predict_step - super().on_predict_dataloader() - collate_fn = self.predict_dataloader().collate_fn # noqa F811 - assert collate_fn.stage == current_running_stage - self._compare_pre_processor( - collate_fn, self.data_pipeline.worker_input_transform_processor(current_running_stage) - ) - assert isinstance(self.transfer_batch_to_device, _StageOrchestrator) - assert isinstance(self.predict_step, _StageOrchestrator) - self._assert_stage_orchestrator_state(self.transfer_batch_to_device._stage_mapping, current_running_stage) - self._assert_stage_orchestrator_state( - self.predict_step._stage_mapping, current_running_stage, cls=_OutputTransformProcessor - ) - - def on_fit_end(self) -> None: - super().on_fit_end() - assert self.train_dataloader().collate_fn == default_collate - assert self.val_dataloader().collate_fn == default_collate - assert self.test_dataloader().collate_fn == default_collate - assert self.predict_dataloader().collate_fn == default_collate - assert not isinstance(self.transfer_batch_to_device, _StageOrchestrator) - assert self.predict_step == self._saved_predict_step - - model = TestModel() - model.data_pipeline = data_pipeline - trainer = Trainer(fast_dev_run=True) - trainer.fit(model) - trainer.test(model) - trainer.predict(model) - - assert model.on_train_dataloader_called - assert model.on_val_dataloader_called - assert model.on_test_dataloader_called - assert model.on_predict_dataloader_called - - -def test_stage_orchestrator_state_attach_detach(tmpdir): - - model = CustomModel() - input_transform = TestInputTransform() - - _original_predict_step = model.predict_step - - class CustomDataPipeline(DataPipeline): - def _attach_output_transform_to_model( - self, model: "Task", _output_transform_processor: _OutputTransformProcessor - ) -> "Task": - model.predict_step = self._model_predict_step_wrapper( - model.predict_step, _output_transform_processor, model - ) - return model - - data_pipeline = CustomDataPipeline(input_transform=input_transform) - _output_transform_processor = data_pipeline._create_output_transform_processor(RunningStage.PREDICTING) - data_pipeline._attach_output_transform_to_model(model, _output_transform_processor) - assert model.predict_step._original == _original_predict_step - assert model.predict_step._stage_mapping[RunningStage.PREDICTING] == _output_transform_processor - data_pipeline._detach_output_transform_from_model(model) - assert model.predict_step == _original_predict_step - - -class CustomModel(Task): - def __init__(self): - super().__init__(model=torch.nn.Linear(1, 1), loss_fn=torch.nn.MSELoss()) - - def training_step(self, batch, batch_idx): - assert batch is None - 
- def validation_step(self, batch, batch_idx): - if isinstance(batch, list): - batch = batch[0] - assert batch is False - - def test_step(self, batch, batch_idx): - assert len(batch) == 2 - assert batch[0].shape == torch.Size([2, 1]) - - def predict_step(self, batch, batch_idx, dataloader_idx=None): - assert batch[0][0] == "a" - assert batch[0][1] == "a" - assert batch[1][0] == "b" - assert batch[1][1] == "b" - return tensor([0, 0, 0]) + def custom_transform(x): + return x + def collate(self): + return self.custom_transform -def test_is_overridden_recursive(tmpdir): - class TestInputTransform(DefaultInputTransform): - def collate(self, *_): - pass - - def val_collate(self, *_): - pass + def val_collate(self): + return self.custom_transform - input_transform = TestInputTransform() + input_transform = TestInputTransform(RunningStage.TRAINING) assert DataPipeline._is_overridden_recursive("collate", input_transform, InputTransform, prefix="val") assert DataPipeline._is_overridden_recursive("collate", input_transform, InputTransform, prefix="train") assert not DataPipeline._is_overridden_recursive( @@ -515,114 +87,3 @@ def val_collate(self, *_): assert not DataPipeline._is_overridden_recursive("per_batch_transform_on_device", input_transform, InputTransform) with pytest.raises(MisconfigurationException, match="This function doesn't belong to the parent class"): assert not DataPipeline._is_overridden_recursive("chocolate", input_transform, InputTransform) - - -def test_input_transform_transforms(tmpdir): - """This test makes sure that when a input_transform is being provided transforms as dictionaries, checking is - done properly, and collate_in_worker_from_transform is properly extracted.""" - - with pytest.raises(MisconfigurationException, match="Transform should be a dict."): - DefaultInputTransform(train_transform="choco") - - with pytest.raises(MisconfigurationException, match="train_transform contains {'choco'}. 
Only"): - DefaultInputTransform(train_transform={"choco": None}) - - input_transform = DefaultInputTransform(train_transform={"per_sample_transform": torch.nn.Linear(1, 1)}) - # keep is None - assert input_transform._train_collate_in_worker_from_transform is True - assert input_transform._val_collate_in_worker_from_transform is None - assert input_transform._test_collate_in_worker_from_transform is None - assert input_transform._predict_collate_in_worker_from_transform is None - - with pytest.raises(MisconfigurationException, match="`per_batch_transform` and `per_sample_transform_on_device`"): - input_transform = DefaultInputTransform( - train_transform={ - "per_batch_transform": torch.nn.Linear(1, 1), - "per_sample_transform_on_device": torch.nn.Linear(1, 1), - } - ) - - input_transform = DefaultInputTransform( - train_transform={"per_batch_transform": torch.nn.Linear(1, 1)}, - predict_transform={"per_sample_transform_on_device": torch.nn.Linear(1, 1)}, - ) - # keep is None - assert input_transform._train_collate_in_worker_from_transform is True - assert input_transform._val_collate_in_worker_from_transform is None - assert input_transform._test_collate_in_worker_from_transform is None - assert input_transform._predict_collate_in_worker_from_transform is False - - train_input_transform_processor = DataPipeline(input_transform=input_transform).worker_input_transform_processor( - RunningStage.TRAINING - ) - val_input_transform_processor = DataPipeline(input_transform=input_transform).worker_input_transform_processor( - RunningStage.VALIDATING - ) - test_input_transform_processor = DataPipeline(input_transform=input_transform).worker_input_transform_processor( - RunningStage.TESTING - ) - predict_input_transform_processor = DataPipeline(input_transform=input_transform).worker_input_transform_processor( - RunningStage.PREDICTING - ) - - assert train_input_transform_processor.collate_fn.func == input_transform.collate - assert val_input_transform_processor.collate_fn.func == input_transform.collate - assert test_input_transform_processor.collate_fn.func == input_transform.collate - assert predict_input_transform_processor.collate_fn.func == DataPipeline._identity - - class CustomInputTransform(DefaultInputTransform): - def per_sample_transform_on_device(self, sample: Any) -> Any: - return super().per_sample_transform_on_device(sample) - - def per_batch_transform(self, batch: Any) -> Any: - return super().per_batch_transform(batch) - - input_transform = CustomInputTransform( - train_transform={"per_batch_transform": torch.nn.Linear(1, 1)}, - predict_transform={"per_sample_transform_on_device": torch.nn.Linear(1, 1)}, - ) - # keep is None - assert input_transform._train_collate_in_worker_from_transform is True - assert input_transform._val_collate_in_worker_from_transform is None - assert input_transform._test_collate_in_worker_from_transform is None - assert input_transform._predict_collate_in_worker_from_transform is False - - data_pipeline = DataPipeline(input_transform=input_transform) - - train_input_transform_processor = data_pipeline.worker_input_transform_processor(RunningStage.TRAINING) - with pytest.raises(MisconfigurationException, match="`per_batch_transform` and `per_sample_transform_on_device`"): - val_input_transform_processor = data_pipeline.worker_input_transform_processor(RunningStage.VALIDATING) - with pytest.raises(MisconfigurationException, match="`per_batch_transform` and `per_sample_transform_on_device`"): - test_input_transform_processor = 
data_pipeline.worker_input_transform_processor(RunningStage.TESTING) - predict_input_transform_processor = data_pipeline.worker_input_transform_processor(RunningStage.PREDICTING) - - assert train_input_transform_processor.collate_fn.func == input_transform.collate - assert predict_input_transform_processor.collate_fn.func == DataPipeline._identity - - -class CustomInputTransformHyperparameters(DefaultInputTransform): - def __init__(self, token: str, *args, **kwargs): - self.token = token - super().__init__(*args, **kwargs) - - @classmethod - def load_from_state_dict(cls, state_dict: Dict[str, Any]): - return cls(state_dict["token"]) - - def state_dict(self) -> Dict[str, Any]: - return {"token": self.token} - - -def local_fn(x): - return x - - -def test_save_hyperparemeters(tmpdir): - - kwargs = {"train_transform": {"per_sample_transform": local_fn}} - input_transform = CustomInputTransformHyperparameters("token", **kwargs) - state_dict = input_transform.state_dict() - torch.save(state_dict, os.path.join(tmpdir, "state_dict.pt")) - state_dict = torch.load(os.path.join(tmpdir, "state_dict.pt")) - input_transform = CustomInputTransformHyperparameters.load_from_state_dict(state_dict) - assert isinstance(input_transform, CustomInputTransformHyperparameters) diff --git a/tests/core/data/test_input_transform.py b/tests/core/data/test_input_transform.py index 385d396c8f..0509f8b3b5 100644 --- a/tests/core/data/test_input_transform.py +++ b/tests/core/data/test_input_transform.py @@ -148,3 +148,54 @@ def input_per_batch_transform(self) -> Callable: return super().input_per_batch_transform MyTransform(1, running_stage=RunningStage.TRAINING) + + +class CustomInputTransform(InputTransform): + @staticmethod + def custom_transform(x): + return x + + def train_per_sample_transform(self): + return self.custom_transform + + def train_per_batch_transform_on_device(self, *_, **__): + return self.custom_transform + + def test_per_sample_transform(self, *_, **__): + return self.custom_transform + + def test_per_batch_transform(self, *_, **__): + return self.custom_transform + + def test_per_sample_transform_on_device(self, *_, **__): + return self.custom_transform + + def test_per_batch_transform_on_device(self, *_, **__): + return self.custom_transform + + def val_per_batch_transform(self, *_, **__): + return self.custom_transform + + def val_per_sample_transform_on_device(self, *_, **__): + return self.custom_transform + + def predict_per_sample_transform(self, *_, **__): + return self.custom_transform + + def predict_per_sample_transform_on_device(self, *_, **__): + return self.custom_transform + + def predict_per_batch_transform_on_device(self, *_, **__): + return self.custom_transform + + +def test_check_transforms(): + + input_transform = CustomInputTransform + + input_transform(RunningStage.TRAINING) + with pytest.raises(MisconfigurationException, match="are mutually exclusive"): + input_transform(RunningStage.VALIDATING) + with pytest.raises(MisconfigurationException, match="are mutually exclusive"): + input_transform(RunningStage.TESTING) + input_transform(RunningStage.PREDICTING) diff --git a/tests/core/data/test_sampler.py b/tests/core/data/test_sampler.py deleted file mode 100644 index fd114d64f2..0000000000 --- a/tests/core/data/test_sampler.py +++ /dev/null @@ -1,32 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
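The new test_check_transforms above fixes the stage-level rule: a worker-side <stage>_per_batch_transform cannot be combined with a device-side <stage>_per_sample_transform_on_device, because the former implies collation in the worker while the latter expects uncollated samples on the device. A sketch following the pattern the test establishes (the checks run in the constructor, so the conflict surfaces at instantiation):

from flash.core.data.input_transform import InputTransform
from flash.core.utilities.stages import RunningStage

class ConflictingTransform(InputTransform):
    # Worker-side per-batch hook for validation...
    def val_per_batch_transform(self):
        return lambda batch: batch

    # ...combined with a device-side per-sample hook for the same stage.
    def val_per_sample_transform_on_device(self):
        return lambda sample: sample

ConflictingTransform(RunningStage.TRAINING)    # fine: no hooks collide for this stage
ConflictingTransform(RunningStage.VALIDATING)  # raises: the two hooks are mutually exclusive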
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from unittest import mock - -from flash import DataModule - - -@mock.patch("flash.core.data.data_module.DataLoader") -def test_dataloaders_with_sampler(mock_dataloader): - train_ds = val_ds = test_ds = "dataset" - mock_sampler = mock.MagicMock() - dm = DataModule(train_ds, val_ds, test_ds, num_workers=0, sampler=mock_sampler) - assert dm.sampler is mock_sampler - dl = dm.train_dataloader() - kwargs = mock_dataloader.call_args[1] - assert "sampler" in kwargs - assert kwargs["sampler"] is mock_sampler.return_value - for dl in [dm.val_dataloader(), dm.test_dataloader()]: - kwargs = mock_dataloader.call_args[1] - assert "sampler" not in kwargs diff --git a/tests/core/data/test_serialization.py b/tests/core/data/test_serialization.py index 94999b37dd..58d7c72f3b 100644 --- a/tests/core/data/test_serialization.py +++ b/tests/core/data/test_serialization.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import os import torch @@ -19,8 +18,6 @@ from pytorch_lightning.callbacks import ModelCheckpoint from torch.utils.data.dataloader import DataLoader -from flash.core.data.data_pipeline import DataPipeline -from flash.core.data.io.input_transform import DefaultInputTransform from flash.core.model import Task @@ -29,12 +26,6 @@ def __init__(self): super().__init__(model=torch.nn.Linear(1, 1), loss_fn=torch.nn.MSELoss()) -class CustomInputTransform(DefaultInputTransform): - @classmethod - def load_data(cls, data): - return data - - def test_serialization_data_pipeline(tmpdir): model = CustomModel() @@ -50,22 +41,12 @@ def test_serialization_data_pipeline(tmpdir): loaded_model = CustomModel.load_from_checkpoint(checkpoint_file) assert loaded_model.data_pipeline - model.data_pipeline = DataPipeline(input_transform=CustomInputTransform()) - assert isinstance(model.input_transform, CustomInputTransform) - trainer.fit(model, dummy_data) assert model.data_pipeline - assert isinstance(model.input_transform, CustomInputTransform) trainer.save_checkpoint(checkpoint_file) - def fn(*args, **kwargs): - return "0.0.2" - - CustomInputTransform.version = fn - loaded_model = CustomModel.load_from_checkpoint(checkpoint_file) assert loaded_model.data_pipeline - assert isinstance(loaded_model.input_transform, CustomInputTransform) for file in os.listdir(tmpdir): if file.endswith(".ckpt"): os.remove(os.path.join(tmpdir, file)) diff --git a/tests/core/test_data.py b/tests/core/test_data.py index 7c77f69075..0b4b37584c 100644 --- a/tests/core/test_data.py +++ b/tests/core/test_data.py @@ -13,7 +13,8 @@ # limitations under the License. 
import torch -from flash import DataModule +from flash import DataKeys, DataModule, RunningStage +from flash.core.data.data_module import DatasetInput # ======== Mock functions ======== @@ -30,26 +31,30 @@ def __len__(self) -> int: def test_init(): - train_ds, val_ds, test_ds = DummyDataset(), DummyDataset(), DummyDataset() - DataModule(train_ds) - DataModule(train_ds, val_ds) - DataModule(train_ds, val_ds, test_ds) - assert DataModule().data_pipeline + train_input = DatasetInput(RunningStage.TRAINING, DummyDataset()) + val_input = DatasetInput(RunningStage.VALIDATING, DummyDataset()) + test_input = DatasetInput(RunningStage.TESTING, DummyDataset()) + DataModule(train_input, batch_size=1) + DataModule(train_input, val_input, batch_size=1) + DataModule(train_input, val_input, test_input, batch_size=1) + assert DataModule(batch_size=1).data_pipeline def test_dataloaders(): - train_ds, val_ds, test_ds = DummyDataset(), DummyDataset(), DummyDataset() - dm = DataModule(train_ds, val_ds, test_ds, num_workers=0) + train_input = DatasetInput(RunningStage.TRAINING, DummyDataset()) + val_input = DatasetInput(RunningStage.VALIDATING, DummyDataset()) + test_input = DatasetInput(RunningStage.TESTING, DummyDataset()) + dm = DataModule(train_input, val_input, test_input, num_workers=0, batch_size=1) for dl in [ dm.train_dataloader(), dm.val_dataloader(), dm.test_dataloader(), ]: - x, y = next(iter(dl)) - assert x.shape == (4, 1, 28, 28) + x = next(iter(dl))[DataKeys.INPUT] + assert x.shape == (1, 1, 28, 28) def test_cpu_count_none(): - train_ds = DummyDataset() - dm = DataModule(train_ds, num_workers=None) + train_input = DatasetInput(RunningStage.TRAINING, DummyDataset()) + dm = DataModule(train_input, num_workers=None, batch_size=1) assert dm.num_workers == 0 diff --git a/tests/core/test_trainer.py b/tests/core/test_trainer.py index 7fcfc341ea..4c51fc9e4a 100644 --- a/tests/core/test_trainer.py +++ b/tests/core/test_trainer.py @@ -27,8 +27,6 @@ from flash import Trainer from flash.core.classification import ClassificationTask -from flash.core.utilities.stages import RunningStage -from tests.helpers.boring_model import BoringModel class DummyDataset(torch.utils.data.Dataset): @@ -140,60 +138,3 @@ def test_from_argparse_args(): trainer = Trainer.from_argparse_args(args) assert trainer.max_epochs == 200 assert isinstance(trainer, Trainer) - - -@pytest.mark.parametrize("stage", ["train", "val", "test"]) -def test_trainer_request_dataloaders_legacy(stage): - """Test to ensure that ``request_dataloaders`` can take the legacy PL ordering of arguments. - - legacy: (model, stage) - """ - - class TestTrainer(Trainer): - recorded_on_dataloader_calls = {} - - def on_train_dataloader(self) -> None: - self.recorded_on_dataloader_calls["train"] = True - - def on_val_dataloader(self) -> None: - self.recorded_on_dataloader_calls["val"] = True - - def on_test_dataloader(self) -> None: - self.recorded_on_dataloader_calls["test"] = True - - model = BoringModel() - trainer = TestTrainer() - - trainer.request_dataloader(model, stage) - assert trainer.recorded_on_dataloader_calls[stage] - - -@pytest.mark.skip(reason="TODO: test can only be enabled once Lightning 1.5 is released.") -@pytest.mark.parametrize("stage", [RunningStage.TRAINING, RunningStage.VALIDATING, RunningStage.TESTING]) -def test_trainer_request_dataloaders(stage): - """Test to ensure that ``request_dataloaders`` can take a combination of arguments, for PL 1.5 and later. 
diff --git a/tests/core/test_trainer.py b/tests/core/test_trainer.py
index 7fcfc341ea..4c51fc9e4a 100644
--- a/tests/core/test_trainer.py
+++ b/tests/core/test_trainer.py
@@ -27,8 +27,6 @@
 
 from flash import Trainer
 from flash.core.classification import ClassificationTask
-from flash.core.utilities.stages import RunningStage
-from tests.helpers.boring_model import BoringModel
 
 
 class DummyDataset(torch.utils.data.Dataset):
@@ -140,60 +138,3 @@ def test_from_argparse_args():
     trainer = Trainer.from_argparse_args(args)
     assert trainer.max_epochs == 200
     assert isinstance(trainer, Trainer)
-
-
-@pytest.mark.parametrize("stage", ["train", "val", "test"])
-def test_trainer_request_dataloaders_legacy(stage):
-    """Test to ensure that ``request_dataloaders`` can take the legacy PL ordering of arguments.
-
-    legacy: (model, stage)
-    """
-
-    class TestTrainer(Trainer):
-        recorded_on_dataloader_calls = {}
-
-        def on_train_dataloader(self) -> None:
-            self.recorded_on_dataloader_calls["train"] = True
-
-        def on_val_dataloader(self) -> None:
-            self.recorded_on_dataloader_calls["val"] = True
-
-        def on_test_dataloader(self) -> None:
-            self.recorded_on_dataloader_calls["test"] = True
-
-    model = BoringModel()
-    trainer = TestTrainer()
-
-    trainer.request_dataloader(model, stage)
-    assert trainer.recorded_on_dataloader_calls[stage]
-
-
-@pytest.mark.skip(reason="TODO: test can only be enabled once Lightning 1.5 is released.")
-@pytest.mark.parametrize("stage", [RunningStage.TRAINING, RunningStage.VALIDATING, RunningStage.TESTING])
-def test_trainer_request_dataloaders(stage):
-    """Test to ensure that ``request_dataloaders`` can take a combination of arguments, for PL 1.5 and later.
-
-    (stage, model) -> calls module on_dataloader hook
-    (stage, model=model) -> calls module on_dataloader hook
-    """
-
-    class TestModel(BoringModel):
-        recorded_on_dataloader_calls = {}
-
-        def on_train_dataloader(self) -> None:
-            self.recorded_on_dataloader_calls[RunningStage.TRAINING] = True
-
-        def on_val_dataloader(self) -> None:
-            self.recorded_on_dataloader_calls[RunningStage.VALIDATING] = True
-
-        def on_test_dataloader(self) -> None:
-            self.recorded_on_dataloader_calls[RunningStage.TESTING] = True
-
-    trainer = Trainer()
-
-    model = TestModel()
-    trainer.request_dataloader(stage, model)
-    assert model.recorded_on_dataloader_calls[stage]
-
-    model = TestModel()
-    trainer.request_dataloader(stage, model=model)
-    assert model.recorded_on_dataloader_calls[stage]
diff --git a/tests/graph/classification/test_model.py b/tests/graph/classification/test_model.py
index d5e4c6c28a..96ed210d3b 100644
--- a/tests/graph/classification/test_model.py
+++ b/tests/graph/classification/test_model.py
@@ -17,7 +17,7 @@
 from flash import RunningStage, Trainer
 from flash.__main__ import main
-from flash.core.data.new_data_module import DataModule
+from flash.core.data.data_module import DataModule
 from flash.core.utilities.imports import _GRAPH_AVAILABLE
 from flash.graph.classification import GraphClassifier
 from flash.graph.classification.input import GraphClassificationDatasetInput
diff --git a/tests/graph/embedding/test_model.py b/tests/graph/embedding/test_model.py
index 844faed936..7cfeaba309 100644
--- a/tests/graph/embedding/test_model.py
+++ b/tests/graph/embedding/test_model.py
@@ -15,7 +15,7 @@
 import torch
 
 from flash import RunningStage, Trainer
-from flash.core.data.new_data_module import DataModule
+from flash.core.data.data_module import DataModule
 from flash.core.utilities.imports import _GRAPH_AVAILABLE
 from flash.graph.classification.input import GraphClassificationDatasetInput
 from flash.graph.classification.input_transform import GraphClassificationInputTransform
diff --git a/tests/image/classification/test_model.py b/tests/image/classification/test_model.py
index b12141f832..11ad1cb29d 100644
--- a/tests/image/classification/test_model.py
+++ b/tests/image/classification/test_model.py
@@ -138,9 +138,7 @@ def test_jit(tmpdir, jitter, args):
 
 @mock.patch("flash._IS_TESTING", True)
 def test_serve():
     model = ImageClassifier(2)
-    # TODO: Currently only servable once a input_transform has been attached
-    model._input_transform = ImageClassificationInputTransform(RunningStage.SERVING)
-    model._deserializer = ImageDeserializer()
+    model._deserializer = ImageDeserializer(transform=ImageClassificationInputTransform(RunningStage.SERVING))
     model.eval()
     model.serve()
diff --git a/tests/image/segmentation/test_model.py b/tests/image/segmentation/test_model.py
index 100180d95f..b6630c0e38 100644
--- a/tests/image/segmentation/test_model.py
+++ b/tests/image/segmentation/test_model.py
@@ -149,9 +149,10 @@ def test_jit(tmpdir, jitter, args):
 
 @mock.patch("flash._IS_TESTING", True)
 def test_serve():
     model = SemanticSegmentation(2)
-    # TODO: Currently only servable once a input_transform has been attached
-    model._input_transform = SemanticSegmentationInputTransform(RunningStage.SERVING)
-    model._deserializer = SemanticSegmentationDeserializer()
+
+    model._deserializer = SemanticSegmentationDeserializer(
+        transform=SemanticSegmentationInputTransform(RunningStage.SERVING)
+    )
     model.eval()
     model.serve()
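Every vision `test_serve` update above applies the same fix: the SERVING-stage transform moves off the model (the private `_input_transform` attribute) and into the deserializer's `transform` argument, which is what retires the "only servable once a input_transform has been attached" TODOs. Schematically, using the image-classification names from the hunk above:

```python
# Serve-time wiring after this change: the deserializer owns the
# SERVING-stage InputTransform; nothing extra is attached to the model.
model = ImageClassifier(2)
model._deserializer = ImageDeserializer(
    transform=ImageClassificationInputTransform(RunningStage.SERVING)
)
model.eval()
model.serve()
```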
diff --git a/tests/tabular/classification/test_model.py b/tests/tabular/classification/test_model.py
index d8c27dae1a..9a73dd1476 100644
--- a/tests/tabular/classification/test_model.py
+++ b/tests/tabular/classification/test_model.py
@@ -19,6 +19,7 @@
 import torch
 from pytorch_lightning import Trainer
 
+from flash import InputTransform, RunningStage
 from flash.__main__ import main
 from flash.core.data.io.input import DataKeys
 from flash.core.utilities.imports import _TABULAR_AVAILABLE
@@ -109,10 +110,9 @@ def test_serve():
         batch_size=1,
     )
     model = TabularClassifier.from_data(datamodule)
-    # TODO: Currently only servable once a input_transform has been attached
-    model._input_transform = datamodule.input_transform
-    model._input_transform._state = datamodule.train_dataset._state
-    model._deserializer = TabularDeserializer()
+
+    model._deserializer = TabularDeserializer(transform=InputTransform(RunningStage.SERVING))
+    model._deserializer._state = datamodule.train_dataset._state
     model.eval()
     model.serve()
diff --git a/tests/text/classification/test_model.py b/tests/text/classification/test_model.py
index 37d8a0cfbf..148fc5f514 100644
--- a/tests/text/classification/test_model.py
+++ b/tests/text/classification/test_model.py
@@ -79,9 +79,7 @@ def test_jit(tmpdir):
 
 @mock.patch("flash._IS_TESTING", True)
 def test_serve():
     model = TextClassifier(2, TEST_BACKBONE)
-    # TODO: Currently only servable once an input_transform has been attached
-    model._input_transform = TransformersInputTransform(RunningStage.SERVING)
-    model._deserializer = TextDeserializer()
+    model._deserializer = TextDeserializer(transform=TransformersInputTransform(RunningStage.SERVING))
     model.eval()
     model.serve()
diff --git a/tests/text/seq2seq/summarization/test_model.py b/tests/text/seq2seq/summarization/test_model.py
index 911037920d..d57843815d 100644
--- a/tests/text/seq2seq/summarization/test_model.py
+++ b/tests/text/seq2seq/summarization/test_model.py
@@ -79,9 +79,7 @@ def test_jit(tmpdir):
 @mock.patch("flash._IS_TESTING", True)
 def test_serve():
     model = SummarizationTask(TEST_BACKBONE)
-    # TODO: Currently only servable once a input_transform has been attached
-    model._input_transform = TransformersInputTransform(RunningStage.SERVING)
-    model._deserializer = TextDeserializer()
+    model._deserializer = TextDeserializer(transform=TransformersInputTransform(RunningStage.SERVING))
     model._output_transform = Seq2SeqOutputTransform()
     model.eval()
     model.serve()
diff --git a/tests/text/seq2seq/translation/test_model.py b/tests/text/seq2seq/translation/test_model.py
index 7639a02223..05ff3cb5da 100644
--- a/tests/text/seq2seq/translation/test_model.py
+++ b/tests/text/seq2seq/translation/test_model.py
@@ -79,9 +79,7 @@ def test_jit(tmpdir):
 @mock.patch("flash._IS_TESTING", True)
 def test_serve():
     model = TranslationTask(TEST_BACKBONE)
-    # TODO: Currently only servable once a input_transform and output_transform have been attached
-    model._input_transform = TransformersInputTransform(RunningStage.SERVING)
-    model._deserializer = TextDeserializer()
+    model._deserializer = TextDeserializer(transform=TransformersInputTransform(RunningStage.SERVING))
     model._output_transform = Seq2SeqOutputTransform()
     model.eval()
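The tabular test is the one variant on that pattern: lacking a task-specific transform, it uses a bare `InputTransform(RunningStage.SERVING)` (now importable straight from `flash`) and copies the preprocessing state fitted on the training split onto the deserializer; the seq2seq tasks additionally keep their `Seq2SeqOutputTransform`. A condensed sketch, assuming the `datamodule` and `model` built in the tabular test above:

```python
# Tabular serve wiring: the deserializer carries both the generic
# SERVING-stage transform and the state fitted on the training data.
from flash import InputTransform, RunningStage

model._deserializer = TabularDeserializer(transform=InputTransform(RunningStage.SERVING))
model._deserializer._state = datamodule.train_dataset._state
model.eval()
model.serve()
```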