diff --git a/docs/source/custom_task.rst b/docs/source/custom_task.rst
index eeb3ef7a1d..3cc818a94f 100644
--- a/docs/source/custom_task.rst
+++ b/docs/source/custom_task.rst
@@ -6,7 +6,8 @@
 along with a custom :class:`~flash.data.data_module.DataModule`.
 The tutorial objective is to create a ``RegressionTask`` to learn to predict if someone has ``diabetes`` or not.
 
-The diabetes (stored as a numpy dataset).
+We will use the ``scikit-learn`` `Diabetes dataset <https://scikit-learn.org/stable/datasets/toy_dataset.html#diabetes-dataset>`__,
+which is stored as numpy arrays.
 
 .. note::
 
@@ -72,8 +73,11 @@ override the ``__init__`` and ``forward`` methods.
 
 .. note::
 
-    Lightning Flash provides an API to register models within a store.
-    Check out :ref:`registry`.
+    Lightning Flash provides registries.
+    Registries are Flash's internal key-value stores, each mapping a name to a function.
+    In simple words, they are just advanced dictionaries that store a function under a key string.
+    They are useful for storing lists of backbones and making them available to a :class:`~flash.core.model.Task`.
+    Check out :ref:`registry` to learn more.
 
 
 Where is the training step?
@@ -99,17 +103,20 @@ Example::
 
         x, y = batch
         y_hat = self(x)
         # compute the logs, loss and metrics as an output dictionary
+        ...
         return output
 
 
 3.a The DataModule API
 ----------------------
 
-First, let's first design the user-facing API.
-We are going to create a ``NumpyDataModule`` class subclassing :class:`~flash.data.data_module.DataModule`.
-This ``NumpyDataModule`` will provide a ``from_xy_dataset`` helper ``classmethod`` to instantiate
+Now that we have defined our ``RegressionTask``, we need to load our data.
+We will define a custom ``NumpyDataModule`` class subclassing :class:`~flash.data.data_module.DataModule`.
+This ``NumpyDataModule`` class will provide a ``from_xy_dataset`` helper ``classmethod`` to instantiate
 :class:`~flash.data.data_module.DataModule` from x, y numpy arrays.
 
+Here is how it would look:
+
 Example::
 
     x, y = ...
@@ -144,6 +151,8 @@ Example::
                 x, y, test_size=.20, random_state=0)
 
             # Make sure to call ``from_load_data_inputs``.
+            # The ``train_load_data_input`` value will be passed to the ``Preprocess``
+            # ``train_load_data`` function.
             dm = cls.from_load_data_inputs(
                 train_load_data_input=(x_train, y_train),
                 test_load_data_input=(x_test, y_test),
@@ -166,22 +175,19 @@ Example::
 
 3.b The Preprocess API
 ----------------------
 
-.. note::
-
-    As new concepts are being introduced, we strongly encourage the reader to click on :class:`~flash.data.process.Preprocess`
-    before going further with the tutorial.
+A :class:`~flash.data.process.Preprocess` object provides a series of hooks that can be overridden with custom data processing logic.
+It gives the user much more granular control over their data processing flow.
 
 .. note::
 
     Why introducing :class:`~flash.data.process.Preprocess` ?
 
-    A :class:`~flash.data.process.Preprocess` object provides a series of hooks that can be overridden with custom data processing logic.
-    The user has much more granular control over their data processing flow.
-
     The :class:`~flash.data.process.Preprocess` object reduces the engineering overhead to make inference on raw data or to
     deploy the model in production environnement compared to traditional `Dataset <https://pytorch.org/docs/stable/data.html#torch.utils.data.Dataset>`_.
 
+    You can override ``predict_{hook_name}`` hooks to handle data processing logic specific to inference.
+
 Example::
 
     import torch
 
@@ -210,10 +216,14 @@ Example::
 
             return torch.from_numpy(sample).float()
 
+You now have a new customized Flash Task. Congratulations!
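+
+As an illustration of the ``predict_{hook_name}`` hooks mentioned above, here is a
+minimal sketch, assuming the ``Preprocess`` subclass from the example above (called
+``NumpyPreprocess`` here); the hook body is purely illustrative::
+
+    class NumpyPreprocess(Preprocess):
+
+        def predict_load_data(self, data):
+            # At inference time, only the raw x values are provided, without targets.
+            return data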
+
+You can fit, finetune, validate and predict directly with these objects.
+
 
 4. Fitting
 ----------
 
-For this task, we will be fitting the ``RegressionTask`` Task on ``scikit-learn`` `Diabetes
+For this task, here is how to fit the ``RegressionTask`` on the ``scikit-learn`` `Diabetes
 dataset <https://scikit-learn.org/stable/datasets/toy_dataset.html#diabetes-dataset>`__.
 
 Like any Flash Task, we can fit our model using the ``flash.Trainer`` by
@@ -226,7 +236,7 @@ supplying the task itself, and the associated data:
 
     model = RegressionTask(num_inputs=datamodule.num_inputs)
     trainer = flash.Trainer(max_epochs=1000)
-    trainer.fit(model, data)
+    trainer.fit(model, datamodule=datamodule)
 
 
 5. Predicting
diff --git a/docs/source/general/callback.rst b/docs/source/general/callback.rst
index b0a6cb588a..74e440216a 100644
--- a/docs/source/general/callback.rst
+++ b/docs/source/general/callback.rst
@@ -16,6 +16,13 @@ Flash and Lightning have a callback system to execute callbacks when needed.
 
 Callbacks should capture any NON-ESSENTIAL logic that is NOT required for your lightning module to run.
 
+As in PyTorch Lightning, callbacks can be provided directly to the Trainer.
+
+Example::
+
+    trainer = Trainer(callbacks=[MyCustomCallback()])
+
+
 *******************
 Available Callbacks
 *******************
@@ -27,10 +34,10 @@ _______________
 
 .. autoclass:: flash.data.callback.BaseDataFetcher
     :members: enable
 
-BaseViz
-_______
+BaseVisualization
+_________________
 
-.. autoclass:: flash.data.base_viz.BaseViz
+.. autoclass:: flash.data.base_viz.BaseVisualization
     :members:
 
diff --git a/docs/source/general/registry.rst b/docs/source/general/registry.rst
index ac499d5f58..0e43d045d6 100644
--- a/docs/source/general/registry.rst
+++ b/docs/source/general/registry.rst
@@ -13,54 +13,72 @@ Registries are Flash internal key-value database to store a mapping between a na
 In simple words, they are just advanced dictionary storing a function from a key string.
 Registries help organize code and make the functions accessible all across the ``Flash`` codebase.
 
-Each Flash ``Task`` can have several registries as static attributes.
-It enables to quickly experiment with your backbone functions or use our long list of available backbones.
+Each Flash :class:`~flash.core.model.Task` can have several registries as static attributes.
+
+Currently, Flash uses registries internally only for backbones, but more components will be added.
+
+1. Imports
+__________
+
+Example::
+
+    from flash.core.registry import FlashRegistry
+
+2. Init a Registry
+__________________
+
+It is good practice to associate one or multiple registries with a Task as follows:
 
 Example::
 
     from flash.vision import ImageClassifier
     from flash.core.registry import FlashRegistry
 
+    # Create a custom ``ImageClassifier`` with its own registry
     class MyImageClassifier(ImageClassifier):
 
        backbones = FlashRegistry("backbones")
 
-    @MyImageClassifier.backbones(name="username/my_backbone")
-    def fn():
-        # Create backbone and backbone output dimension (`num_features`)
-        return backbone, num_features
-
-    # The new key should be listed in available backbones
-    print(MyImageClassifier.available_backbones())
-    # out: ["username/my_backbone"]
-
-    # Create a model with your backbone !
-    model = MyImageClassifier(backbone="username/my_backbone")
+3. Adding new functions
+_______________________
 
 Your custom functions can be registered within a :class:`~flash.core.registry.FlashRegistry` as a decorator or directly.
 
 Example::
 
-    from functools import partial
-
-    # Create a registry
-    backbones = FlashRegistry("backbones")
-
     # Option 1: Used with partial.
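+    # Note: in this snippet, `backbone` and `num_features` are placeholders for a
+    # real model object and its output dimension.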
     def fn(backbone: str):
         # Create backbone and backbone output dimension (`num_features`)
         return backbone, num_features
 
     # HINT 1: Use `from functools import partial` if you want to store some arguments.
-    backbones(fn=partial(fn, backbone="my_backbone"), name="username/my_backbone")
+    MyImageClassifier.backbones(fn=partial(fn, backbone="my_backbone"), name="username/my_backbone")
 
     # Option 2: Using decorator.
-    @backbones(name="username/my_backbone")
+    @MyImageClassifier.backbones(name="username/my_backbone")
     def fn():
         # Create backbone and backbone output dimension (`num_features`)
         return backbone, num_features
 
+4. Accessing registered functions
+_________________________________
+
+You can now access your function from your task!
+
+Example::
+
+    # Optional: list the available backbones
+    print(MyImageClassifier.available_backbones())
+    # out: ["username/my_backbone"]
+
+    # Build the model
+    model = MyImageClassifier(backbone="username/my_backbone", num_classes=2)
+
+
+5. Pre-populated registries
+___________________________
+
 Flash provides already populated registries containing lot of available backbones.
 
 Example::
@@ -73,8 +91,6 @@ Example::
 
     """
 
-
-
 **************
 Flash Registry
 **************
diff --git a/flash/data/base_viz.py b/flash/data/base_viz.py
index 7efe975e43..403a5fe081 100644
--- a/flash/data/base_viz.py
+++ b/flash/data/base_viz.py
@@ -7,7 +7,7 @@
 from flash.data.utils import _PREPROCESS_FUNCS
 
 
-class BaseViz(BaseDataFetcher):
+class BaseVisualization(BaseDataFetcher):
     """
     This Base Class is used to create visualization tool on top of :class:`~flash.data.process.Preprocess` hooks.
 
@@ -16,9 +16,9 @@ class BaseViz(BaseDataFetcher):
     Example::
 
         from flash.vision import ImageClassificationData
-        from flash.data.base_viz import BaseViz
+        from flash.data.base_viz import BaseVisualization
 
-        class CustomBaseViz(BaseViz):
+        class CustomBaseVisualization(BaseVisualization):
 
             def show_load_sample(self, samples: List[Any], running_stage):
                 # plot samples
@@ -42,7 +42,7 @@ class CustomImageClassificationData(ImageClassificationData):
 
             @staticmethod
             def configure_data_fetcher(*args, **kwargs) -> BaseDataFetcher:
-                return CustomBaseViz(*args, **kwargs)
+                return CustomBaseVisualization(*args, **kwargs)
 
         dm = CustomImageClassificationData.from_folders(
             train_folder="./data/train",
@@ -72,7 +72,7 @@ def configure_data_fetcher(*args, **kwargs) -> BaseDataFetcher:
 
     Example::
 
-        class CustomBaseViz(BaseViz):
+        class CustomBaseVisualization(BaseVisualization):
 
             def show(self, batch: Dict[str, Any], running_stage: RunningStage):
                 print(batch)
@@ -103,7 +103,7 @@ def show(self, batch: Dict[str, Any], running_stage: RunningStage) -> None:
         """
         for func_name in _PREPROCESS_FUNCS:
             hook_name = f"show_{func_name}"
-            if _is_overriden(hook_name, self, BaseViz):
+            if _is_overriden(hook_name, self, BaseVisualization):
                 getattr(self, hook_name)(batch[func_name], running_stage)
 
     def show_load_sample(self, samples: List[Any], running_stage: RunningStage):
diff --git a/flash/data/data_module.py b/flash/data/data_module.py
index 9a022465fb..890c0a6661 100644
--- a/flash/data/data_module.py
+++ b/flash/data/data_module.py
@@ -24,7 +24,7 @@
 from torch.utils.data.dataset import Subset
 
 from flash.data.auto_dataset import AutoDataset
-from flash.data.base_viz import BaseViz
+from flash.data.base_viz import BaseVisualization
 from flash.data.callback import BaseDataFetcher
 from flash.data.data_pipeline import DataPipeline, Postprocess, Preprocess
 from flash.data.utils import _STAGES_PREFIX
@@ -86,7 +86,7 @@ def __init__(
         self._preprocess: Optional[Preprocess] = None
         self._postprocess: Optional[Postprocess] = None
-        self._viz: Optional[BaseViz] = None
+        self._viz: Optional[BaseVisualization] = None
         self._data_fetcher: Optional[BaseDataFetcher] = None
 
         # this may also trigger data preloading
@@ -113,11 +113,11 @@ def predict_dataset(self) -> Optional[Dataset]:
         return self._predict_ds
 
     @property
-    def viz(self) -> BaseViz:
+    def viz(self) -> BaseVisualization:
         return self._viz or DataModule.configure_data_fetcher()
 
     @viz.setter
-    def viz(self, viz: BaseViz) -> None:
+    def viz(self, viz: BaseVisualization) -> None:
         self._viz = viz
 
     @staticmethod
@@ -159,7 +159,7 @@ def _show_batch(self, stage: RunningStage, reset: bool = True) -> None:
             except StopIteration:
                 iter_dataloader = self._reset_iterator(stage)
                 _ = next(iter_dataloader)
-        data_fetcher: BaseViz = self.data_fetcher
+        data_fetcher: BaseVisualization = self.data_fetcher
         data_fetcher._show(stage)
         if reset:
             self.viz.batches[stage] = {}
diff --git a/tests/data/test_callbacks.py b/tests/data/test_callbacks.py
index 09d66c1094..b0d96a5252 100644
--- a/tests/data/test_callbacks.py
+++ b/tests/data/test_callbacks.py
@@ -22,7 +22,7 @@
 from pytorch_lightning.trainer.states import RunningStage
 from torch import tensor
 
-from flash.data.base_viz import BaseViz
+from flash.data.base_viz import BaseVisualization
 from flash.data.callback import BaseDataFetcher
 from flash.data.data_module import DataModule
 from flash.data.utils import _STAGES_PREFIX
@@ -92,7 +92,7 @@ def test_base_viz(tmpdir):
     _rand_image().save(tmpdir / "b" / "a_1.png")
     _rand_image().save(tmpdir / "b" / "a_2.png")
 
-    class CustomBaseViz(BaseViz):
+    class CustomBaseVisualization(BaseVisualization):
 
         show_load_sample_called = False
         show_pre_tensor_transform_called = False
@@ -130,8 +130,8 @@ def check_reset(self):
     class CustomImageClassificationData(ImageClassificationData):
 
         @staticmethod
-        def configure_data_fetcher(*args, **kwargs) -> CustomBaseViz:
-            return CustomBaseViz(*args, **kwargs)
+        def configure_data_fetcher(*args, **kwargs) -> CustomBaseVisualization:
+            return CustomBaseVisualization(*args, **kwargs)
 
     dm = CustomImageClassificationData.from_filepaths(
         train_filepaths=[tmpdir / "a", tmpdir / "b"],