From 14502b058bdd814c6983c5b5810d21403cc9472c Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Wed, 14 Apr 2021 16:02:32 +0100 Subject: [PATCH] Fix `DataPipeline` resolution in `Task` (#212) * Initial commit * Fix docstrings * Fix broken tests * Populate default_preprocess in embedding model --- flash/core/classification.py | 3 +- flash/core/model.py | 124 ++++++++++++++++++++----------- flash/vision/embedding/model.py | 7 +- tests/core/test_model.py | 4 +- tests/data/test_data_pipeline.py | 20 +++-- tests/data/test_serialization.py | 4 +- 6 files changed, 105 insertions(+), 57 deletions(-) diff --git a/flash/core/classification.py b/flash/core/classification.py index 4340f404b5..f82de91c5f 100644 --- a/flash/core/classification.py +++ b/flash/core/classification.py @@ -28,5 +28,4 @@ def per_sample_transform(self, samples: Any) -> Any: class ClassificationTask(Task): def __init__(self, *args, **kwargs) -> None: - super().__init__(*args, **kwargs) - self._postprocess = ClassificationPostprocess() + super().__init__(*args, default_postprocess=ClassificationPostprocess(), **kwargs) diff --git a/flash/core/model.py b/flash/core/model.py index a0f71c547b..68d17173fe 100644 --- a/flash/core/model.py +++ b/flash/core/model.py @@ -14,7 +14,7 @@ import functools import inspect from copy import deepcopy -from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Type, Union +from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Type, Union import torch import torchmetrics @@ -59,7 +59,9 @@ class Task(LightningModule): loss_fn: Loss function for training optimizer: Optimizer to use for training, defaults to `torch.optim.Adam`. metrics: Metrics to compute for training and evaluation. - learning_rate: Learning rate to use for training, defaults to `5e-5` + learning_rate: Learning rate to use for training, defaults to `5e-5`. + default_preprocess: :class:`.Preprocess` to use as the default for this task. + default_postprocess: :class:`.Postprocess` to use as the default for this task. """ def __init__( @@ -69,6 +71,8 @@ def __init__( optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam, metrics: Union[torchmetrics.Metric, Mapping, Sequence, None] = None, learning_rate: float = 5e-5, + default_preprocess: Preprocess = None, + default_postprocess: Postprocess = None, ): super().__init__() if model is not None: @@ -80,9 +84,8 @@ def __init__( # TODO: should we save more? Bug on some regarding yaml if we save metrics self.save_hyperparameters("learning_rate", "optimizer") - self._data_pipeline = None - self._preprocess = None - self._postprocess = None + self._preprocess = default_preprocess + self._postprocess = default_postprocess def step(self, batch: Any, batch_idx: int) -> Any: """ @@ -142,7 +145,9 @@ def predict( The post-processed model predictions """ running_stage = RunningStage.PREDICTING - data_pipeline = data_pipeline or self.data_pipeline + + data_pipeline = self.build_data_pipeline(data_pipeline) + x = [x for x in data_pipeline._generate_auto_dataset(x, running_stage)] x = data_pipeline.worker_preprocessor(running_stage)(x) x = self.transfer_batch_to_device(x, self.device) @@ -165,56 +170,91 @@ def configure_optimizers(self) -> torch.optim.Optimizer: def configure_finetune_callback(self) -> List[Callback]: return [] - @property - def preprocess(self) -> Optional[Preprocess]: - return getattr(self._data_pipeline, '_preprocess_pipeline', None) or self._preprocess + @staticmethod + def _resolve( + old_preprocess: Optional[Preprocess], + old_postprocess: Optional[Postprocess], + new_preprocess: Optional[Preprocess], + new_postprocess: Optional[Postprocess], + ) -> Tuple[Optional[Preprocess], Optional[Postprocess]]: + """Resolves the correct :class:`.Preprocess` and :class:`.Postprocess` to use, choosing ``new_*`` if it is not + None or a base class (:class:`.Preprocess` or :class:`.Postprocess`) and ``old_*`` otherwise. - @preprocess.setter - def preprocess(self, preprocess: Preprocess) -> None: - self._preprocess = preprocess - self.data_pipeline = DataPipeline(preprocess, self.postprocess) + Args: + old_preprocess: :class:`.Preprocess` to be overridden. + old_postprocess: :class:`.Postprocess` to be overridden. + new_preprocess: :class:`.Preprocess` to override with. + new_postprocess: :class:`.Postprocess` to override with. - @property - def postprocess(self) -> Postprocess: - postprocess_cls = getattr(self, "postprocess_cls", None) - return ( - self._postprocess or (postprocess_cls() if postprocess_cls else None) - or getattr(self._data_pipeline, '_postprocess_pipeline', None) or Postprocess() - ) + Returns: + The resolved :class:`.Preprocess` and :class:`.Postprocess`. + """ + preprocess = old_preprocess + if new_preprocess is not None and type(new_preprocess) != Preprocess: + preprocess = new_preprocess - @postprocess.setter - def postprocess(self, postprocess: Postprocess) -> None: - self.data_pipeline = DataPipeline(self.preprocess, postprocess) - self._postprocess = postprocess + postprocess = old_postprocess + if new_postprocess is not None and type(new_postprocess) != Postprocess: + postprocess = new_postprocess - @property - def data_pipeline(self) -> Optional[DataPipeline]: - if self._data_pipeline is not None: - return self._data_pipeline + return preprocess, postprocess - elif self.preprocess is not None or self.postprocess is not None: - # use direct attributes here to avoid recursion with properties that also check the data_pipeline property - return DataPipeline(self.preprocess, self.postprocess) + def build_data_pipeline(self, data_pipeline: Optional[DataPipeline] = None) -> Optional[DataPipeline]: + """Build a :class:`.DataPipeline` incorporating available :class:`.Preprocess` and :class:`.Postprocess` + objects. These will be overridden in the following resolution order (lowest priority first): - elif self.datamodule is not None and getattr(self.datamodule, 'data_pipeline', None) is not None: - return self.datamodule.data_pipeline + - Lightning ``Datamodule``, either attached to the :class:`.Trainer` or to the :class:`.Task`. + - :class:`.Task` defaults given to ``.Task.__init__``. + - :class:`.Task` manual overrides by setting :py:attr:`~data_pipeline`. + - :class:`.DataPipeline` passed to this method. + + Args: + data_pipeline: Optional highest priority source of :class:`.Preprocess` and :class:`.Postprocess`. + + Returns: + The fully resolved :class:`.DataPipeline`. + """ + preprocess, postprocess = None, None + + # Datamodule + if self.datamodule is not None and getattr(self.datamodule, 'data_pipeline', None) is not None: + preprocess = getattr(self.datamodule.data_pipeline, '_preprocess_pipeline', None) + postprocess = getattr(self.datamodule.data_pipeline, '_postprocess_pipeline', None) elif self.trainer is not None and hasattr( self.trainer, 'datamodule' ) and getattr(self.trainer.datamodule, 'data_pipeline', None) is not None: - return self.trainer.datamodule.data_pipeline + preprocess = getattr(self.trainer.datamodule.data_pipeline, '_preprocess_pipeline', None) + postprocess = getattr(self.trainer.datamodule.data_pipeline, '_postprocess_pipeline', None) + + # Defaults / task attributes + preprocess, postprocess = Task._resolve(preprocess, postprocess, self._preprocess, self._postprocess) - return self._data_pipeline + # Datapipeline + if data_pipeline is not None: + preprocess, postprocess = Task._resolve( + preprocess, + postprocess, + getattr(data_pipeline, '_preprocess_pipeline', None), + getattr(data_pipeline, '_postprocess_pipeline', None), + ) + + return DataPipeline(preprocess, postprocess) + + @property + def data_pipeline(self) -> DataPipeline: + """The current :class:`.DataPipeline`. If set, the new value will override the :class:`.Task` defaults. See + :py:meth:`~build_data_pipeline` for more details on the resolution order.""" + return self.build_data_pipeline() @data_pipeline.setter def data_pipeline(self, data_pipeline: Optional[DataPipeline]) -> None: - self._data_pipeline = data_pipeline - if data_pipeline is not None and getattr(data_pipeline, '_preprocess_pipeline', None) is not None: - self._preprocess = data_pipeline._preprocess_pipeline - - if data_pipeline is not None and getattr(data_pipeline, '_postprocess_pipeline', None) is not None: - if type(data_pipeline._postprocess_pipeline) != Postprocess: - self._postprocess = data_pipeline._postprocess_pipeline + self._preprocess, self._postprocess = Task._resolve( + self._preprocess, + self._postprocess, + getattr(data_pipeline, '_preprocess_pipeline', None), + getattr(data_pipeline, '_postprocess_pipeline', None), + ) def on_train_dataloader(self) -> None: if self.data_pipeline is not None: diff --git a/flash/vision/embedding/model.py b/flash/vision/embedding/model.py index 675a99071d..d1ddce5c84 100644 --- a/flash/vision/embedding/model.py +++ b/flash/vision/embedding/model.py @@ -49,10 +49,6 @@ class ImageEmbedder(Task): backbones: FlashRegistry = IMAGE_CLASSIFIER_BACKBONES - @property - def preprocess(self): - return ImageClassificationPreprocess(predict_transform=ImageClassificationData.default_val_transforms()) - def __init__( self, embedding_dim: Optional[int] = None, @@ -70,6 +66,9 @@ def __init__( optimizer=optimizer, metrics=metrics, learning_rate=learning_rate, + default_preprocess=ImageClassificationPreprocess( + predict_transform=ImageClassificationData.default_val_transforms(), + ) ) self.save_hyperparameters() diff --git a/tests/core/test_model.py b/tests/core/test_model.py index 9c6dcbe0d9..d0b0048b23 100644 --- a/tests/core/test_model.py +++ b/tests/core/test_model.py @@ -115,7 +115,7 @@ def test_task_datapipeline_save(tmpdir): task = ClassificationTask(model, F.nll_loss) # to check later - task.postprocess.test = True + task._postprocess.test = True # generate a checkpoint trainer = pl.Trainer( @@ -132,7 +132,7 @@ def test_task_datapipeline_save(tmpdir): # load from file task = ClassificationTask.load_from_checkpoint(path, model=model) - assert task.postprocess.test + assert task._postprocess.test @pytest.mark.parametrize( diff --git a/tests/data/test_data_pipeline.py b/tests/data/test_data_pipeline.py index a175b6387c..d1682c053c 100644 --- a/tests/data/test_data_pipeline.py +++ b/tests/data/test_data_pipeline.py @@ -71,8 +71,16 @@ class SubPostprocess(Postprocess): model = CustomModel(postprocess=Postprocess()) model.data_pipeline = data_pipeline - assert isinstance(model._preprocess, SubPreprocess if use_preprocess else Preprocess) - assert isinstance(model._postprocess, SubPostprocess if use_postprocess else Postprocess) + + if use_preprocess: + assert isinstance(model._preprocess, SubPreprocess) + else: + assert model._preprocess is None or isinstance(model._preprocess, Preprocess) + + if use_postprocess: + assert isinstance(model._postprocess, SubPostprocess) + else: + assert model._postprocess is None or isinstance(model._postprocess, Postprocess) def test_data_pipeline_is_overriden_and_resolve_function_hierarchy(tmpdir): @@ -330,15 +338,17 @@ def predict_per_batch_transform_on_device(self, *_, **__): def test_attaching_datapipeline_to_model(tmpdir): - preprocess = Preprocess() + class SubPreprocess(Preprocess): + pass + + preprocess = SubPreprocess() data_pipeline = DataPipeline(preprocess) class CustomModel(Task): - _postprocess = Postprocess() - def __init__(self): super().__init__(model=torch.nn.Linear(1, 1), loss_fn=torch.nn.MSELoss()) + self._postprocess = Postprocess() def training_step(self, batch: Any, batch_idx: int) -> Any: pass diff --git a/tests/data/test_serialization.py b/tests/data/test_serialization.py index fda5cb7643..bc35fc0eb4 100644 --- a/tests/data/test_serialization.py +++ b/tests/data/test_serialization.py @@ -57,11 +57,11 @@ def test_serialization_data_pipeline(tmpdir): trainer.fit(model, dummy_data) assert model.data_pipeline - assert isinstance(model.preprocess, CustomPreprocess) + assert isinstance(model._preprocess, CustomPreprocess) trainer.save_checkpoint(checkpoint_file) loaded_model = CustomModel.load_from_checkpoint(checkpoint_file) assert loaded_model.data_pipeline - assert isinstance(loaded_model.preprocess, CustomPreprocess) + assert isinstance(loaded_model._preprocess, CustomPreprocess) for file in os.listdir(tmpdir): if file.endswith('.ckpt'): os.remove(os.path.join(tmpdir, file))