From f5e3c498502c7d0d6a464569d40be650e9e6e94c Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Fri, 14 May 2021 16:31:40 +0100 Subject: [PATCH 01/53] Initial commit --- docs/source/index.rst | 1 + docs/source/task_template.rst | 75 +++++++++++ flash/template/README.md | 96 ++++++++++++++ flash/template/__init__.py | 0 flash/template/data.py | 234 ++++++++++++++++++++++++++++++++++ flash/template/model.py | 130 +++++++++++++++++++ 6 files changed, 536 insertions(+) create mode 100644 docs/source/task_template.rst create mode 100644 flash/template/README.md create mode 100644 flash/template/__init__.py create mode 100644 flash/template/data.py create mode 100644 flash/template/model.py diff --git a/docs/source/index.rst b/docs/source/index.rst index 49de343ea3..836fb7003e 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -13,6 +13,7 @@ Lightning Flash quickstart installation custom_task + task_template reference/flash_to_pl .. toctree:: diff --git a/docs/source/task_template.rst b/docs/source/task_template.rst new file mode 100644 index 0000000000..01f46ebab0 --- /dev/null +++ b/docs/source/task_template.rst @@ -0,0 +1,75 @@ +Task Template +============= + +This template is designed to guide you through implementing your own task in flash. +You should copy the files in ``flash/template`` and adapt them to your custom task. + +Required Files +-------------- + +``data.py`` +~~~~~~~~~~~ + +Inside ``data.py`` you should implement: + +#. one or more :class:`~flash.data.data_source.DataSource` classes +#. a :class:`~flash.data.process.Preprocess` +#. a :class:`~flash.data.data_module.DataModule` +#. a :class:`~flash.data.callbacks.BaseVisualization` *(optional)* +#. a :class:`~flash.data.process.Postprocess` *(optional)* + +``DataSource`` +^^^^^^^^^^^^^^ + +The :class:`~flash.data.data_source.DataSource` implementations describe how data from particular sources (like folders, files, tensors, etc.) should be loaded. +At a minimum you will require one :class:`~flash.data.data_source.DataSource` implementation, but if you want to support a few different ways of loading data for your task, the more the merrier! + +Take a look at our ``TemplateNumpyDataSource`` to get started: + +.. raw:: html + +
+ Click to expand + +.. autoclass:: flash.template.data.TemplateNumpyDataSource + :members: + +.. raw:: html + +
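
To make the pattern concrete, here's a minimal sketch of what a fully custom :class:`~flash.data.data_source.DataSource` could look like.
The folder-walking logic and the ``MyFileDataSource`` name are illustrative assumptions rather than part of the template:

.. code-block:: python

    import os
    from typing import Any, Dict, List

    from flash.data.data_source import DataSource, DefaultDataKeys


    class MyFileDataSource(DataSource):
        """Illustrative data source that creates one sample per file in a folder."""

        def load_data(self, data: str, dataset: Any) -> List[Dict[str, Any]]:
            # ``data`` is a folder path; we return a sequence of sample metadata.
            files = sorted(os.listdir(data))
            dataset.num_samples = len(files)  # attributes set here end up on the dataset
            return [{DefaultDataKeys.INPUT: os.path.join(data, file_name)} for file_name in files]

        def load_sample(self, sample: Dict[str, Any], dataset: Any = None) -> Dict[str, Any]:
            # Called once per sample, so the (potentially expensive) loading happens lazily.
            with open(sample[DefaultDataKeys.INPUT], "rb") as f:
                sample[DefaultDataKeys.INPUT] = f.read()
            return sample
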
+ +And have a look at our ``TemplateSKLearnDataSource`` for another example: + +.. raw:: html + +
+ Click to expand + +.. autoclass:: flash.template.data.TemplateSKLearnDataSource + :members: + +.. raw:: html + +
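
As a rough sketch of how this behaves (assuming scikit-learn is installed and the signatures shown above), ``load_data`` can be exercised directly with one of the scikit-learn ``Bunch`` objects.
The ``DummyDataset`` stand-in below is purely illustrative:

.. code-block:: python

    from sklearn import datasets

    from flash.template.data import TemplateSKLearnDataSource


    class DummyDataset:
        """Stand-in for the dataset object that Flash normally passes to ``load_data``."""


    dataset = DummyDataset()
    data_source = TemplateSKLearnDataSource()
    samples = data_source.load_data(datasets.load_iris(), dataset)

    print(dataset.num_classes)  # 3 classes in the Iris data set
    print(dataset.num_features)  # 4 features per example
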
+ +``Preprocess`` +^^^^^^^^^^^^^^ + +The ``Preprocess`` is how all transforms are defined in Flash. +Internally we inject the ``Preprocess`` transforms into the right places so that we can address the batch at several points along the pipeline. +Defining the standard transforms (typically at least a ``to_tensor_transform`` should be defined) for your ``Preprocess`` is as simple as implementing the ``default_transforms`` method. +The ``Preprocess`` also knows about the available `DataSource` classes that it can work with, which should be configured in the ``__init__``. + +Take a look at our ``TemplatePreprocess`` to get started: + +.. raw:: html + +
+ Click to expand + +.. autoclass:: flash.template.data.TemplatePreprocess + :members: + +.. raw:: html + +
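
As a hedged sketch of the user-facing side, a ``Preprocess`` built this way can be constructed with custom per-stage transforms.
The ``add_noise`` augmentation below is an illustrative assumption, not something the template ships:

.. code-block:: python

    import torch
    from torch import nn

    from flash.data.data_source import DefaultDataKeys
    from flash.data.transforms import ApplyToKeys
    from flash.template.data import TemplatePreprocess


    def add_noise(x: torch.Tensor) -> torch.Tensor:
        """Illustrative augmentation: jitter the inputs with Gaussian noise."""
        return x + 0.1 * torch.randn_like(x)


    preprocess = TemplatePreprocess(
        train_transform={
            # Convert the numpy inputs and targets to tensors, as in the defaults.
            "to_tensor_transform": nn.Sequential(
                ApplyToKeys(DefaultDataKeys.INPUT, torch.from_numpy),
                ApplyToKeys(DefaultDataKeys.TARGET, torch.as_tensor),
            ),
            # Then apply the extra augmentation to the input tensors only.
            "post_tensor_transform": ApplyToKeys(DefaultDataKeys.INPUT, add_noise),
        },
    )

Transforms for the other stages keep their defaults, so validation and testing are unaffected.
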
diff --git a/flash/template/README.md b/flash/template/README.md new file mode 100644 index 0000000000..834c376d07 --- /dev/null +++ b/flash/template/README.md @@ -0,0 +1,96 @@ +# Lightning Flash Task Template + +This template is designed to guide you through implementing your own task in flash. +You should copy the files here and adapt them to your custom task. + +## Required Files: + +### `data.py` + +Inside `data.py` you should implement: + +1. one or more `DataSource` classes +2. a `Preprocess` +3. a `DataModule` +4. a `BaseVisualization` __(optional)__ +5. a `Postprocess` __(optional)__ + +#### `DataSource` + +The `DataSource` implementations describe how data from particular sources (like folders, files, tensors, etc.) should be loaded. +At a minimum you will require one `DataSource` implementation, but if you want to support a few different ways of loading data for your task, the more the merrier! +Take a look at our `TemplateDataSource` to get started. + +#### `Preprocess` + +The `Preprocess` is how all transforms are defined in Flash. +Internally we inject the `Preprocess` transforms into the right places so that we can address the batch at several points along the pipeline. +Defining the standard transforms (typically at least a `to_tensor_transform` and a `collate` should be defined) for your `Preprocess` is as simple as implementing the `default_transforms` method. +The `Preprocess` also knows about the available `DataSource` classes that it can work with, which should be configured in the `__init__`. +Take a look at our `TemplatePreprocess` to get started. + +#### `DataModule` + +The `DataModule` is where the hard work of our `DataSource` and `Preprocess` implementations pays off. +If your `DataSource` implementation(s) conform to our `DefaultDataSources` (e.g. `DefaultDataSources.FOLDERS`) then your `DataModule` implementation simply needs a `preprocess_cls` attribute. +You now have a `DataModule` that can be instantiated with `from_*` for whichever data sources you have configured (e.g. `MyDataModule.from_folders`). +It also includes all of your default transforms! + +If you've defined a fully custom `DataSource`, then you will need a `preprocess_cls` attribute and one or more `from_*` methods. +The `from_*` methods take whatever arguments you want them too and call `super().from_data_source` with the name given to your custom data source in the `Preprocess.__init__`. +Take a look at our `TemplateDataModule` to get started. + +#### `BaseVisualization` + +A completely optional step is to implement a `BaseVisualization`. The `BaseVisualization` lets you control how data at various points in the pipeline can be visualized. +This is extremely useful for debugging purposes, allowing users to view their data and understand the impact of their transforms. +Take a look at our `TemplateVisualization` to get started, but don't worry about implementing it right away, you can always come back and add it later! + +#### `Postprocess` + +Sometimes you have some transforms that need to be applied _after_ your model. +For this you can optionally implement a `Postprocess`. +The `Postprocess` is applied to the model outputs during inference. +You may want to use it for: converting tokens back into text, applying an inverse normalization to an output image, resizing a generated image back to the size of the input, etc. +For information and some examples, take a look at our postprocess docs. + +### `model.py` + +Inside `model.py` you just need to implement your `Task`. 
+ +#### `Task` + +The `Task` is responsible for the forward pass of the model. +It's just a `LightningModule` with some helpful defaults, so anything you can do inside a `LightningModule` you can do inside a `Task`. +You should configure a default loss function and optimizer and some default metrics and models in your `Task`. +Take a look at our `TemplateTask` to get started. + +### `flash_examples` + +Now you've implemented your task, it's time to add some examples showing how cool it is! +We usually provide one finetuning example (in `flash_examples/finetuning`) and one predict / inference example (in `flash_examples/predict`). +You can base these off of our `template.py` examples. + +## Optional Files: + +### `transforms.py` + +Sometimes you'd like to have quite a few transforms by default (standard augmentations, normalization, etc.). +If you do then, for better organization, you can define a `transforms.py` which houses your default transforms to be referenced in your `Preprocess`. +Take a look at `vision/classification/transforms.py` for an example. + +### `backbones.py` + +In Flash, we love to provide as much access to the state-of-the-art as we can. +To this end, we've created the `FlashRegistry` and the backbones API. +These allow you to register backbones for your task that can be selected by the user. +The backbones can come from anywhere as long as you can register a function that loads the backbone. +If you want to configure some backbones for your task, it's best practice to include these in a `backbones.py` file. +Take a look at `vision/backbones.py` for an example, and have a look at `vision/classification/model.py` to see how these can be added to your `Task`. + +### `serialization.py` + +Sometimes you want to give the user some control over their prediction format. +`Postprocess` can do the heavy lifting (anything you always want to apply to the predictions), but one or more custom `Serializer` implementations can be used to convert the predictions to a desired output format. +A good example is in classification; sometimes we'd like the classes, sometimes the logits, sometimes the labels, you get the idea. +You should add your `Serializer` implementations in a `serialization.py` file and set a good default in your `Task`. diff --git a/flash/template/__init__.py b/flash/template/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/flash/template/data.py b/flash/template/data.py new file mode 100644 index 0000000000..599910878a --- /dev/null +++ b/flash/template/data.py @@ -0,0 +1,234 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple + +import numpy as np +import torch +from pytorch_lightning.trainer.states import RunningStage +from sklearn.utils import Bunch +from torch import nn + +from flash.data.base_viz import BaseVisualization +from flash.data.callback import BaseDataFetcher +from flash.data.data_module import DataModule +from flash.data.data_source import DefaultDataKeys, DefaultDataSources, LabelsState, NumpyDataSource +from flash.data.process import Preprocess +from flash.data.transforms import ApplyToKeys + + +class TemplateNumpyDataSource(NumpyDataSource): + """An example data source that records ``num_features`` on the dataset. We extend + :class:`~flash.data.data_source.NumpyDataSource` so that we can use ``super().load_data``.""" + + def load_data(self, data: Tuple[np.ndarray, Sequence[Any]], dataset: Any) -> Sequence[Mapping[str, Any]]: + """The main :class:`~flash.data.data_source.DataSource` method that we have to implement is + :meth:`~flash.data.data_source.DataSource.load_data`. As we're extending the ``NumpyDataSource``, we expect the + same ``data`` argument (in this case, a tuple containing data and corresponding target arrays). + + We can also take the dataset argument. Any attributes we set on ``dataset`` will be available on the ``Dataset`` + generated by our ``DataSource``. In this data source, we'll set the ``num_features`` attribute. Have a look at + our ``DataModule`` implementation to see how we make ``num_features`` available there. + + Args: + data: The tuple of ``np.ndarray`` (num_examples x num_features) and associated targets. + dataset: The object that we can set attributes (such as ``num_features``) on. + + Returns: + A sequence of samples / sample metadata. + + Source + .. literalinclude:: ../../flash/template/data.py + :language: python + :lines: 34, 55-58 + """ + dataset.num_features = data[0].shape[1] + + # Now we just call super + return super().load_data(data, dataset) + + +class TemplateSKLearnDataSource(TemplateNumpyDataSource): + """An example data source that loads data from an sklearn data ``Bunch``.""" + + def load_data(self, data: Bunch, dataset: Any) -> Sequence[Mapping[str, Any]]: + """Here we're creating a fully custom :class:`~flash.data.data_source.DataSource` (that is, we're not going to + treat it as one of the :class:`~flash.data.data_source.DefaultDataSources`) so the type of the ``data`` argument + is up to us. In this case, we want to be able to use a scikit-learn data ``Bunch`` as an input. + + On our dataset, we'll set the ``num_classes`` attribute. This is a standard practice in our classification tasks + and ``num_classes`` will automatically be made available by the :class:`~flash.data.data_module.DataModule`. + + Args: + data: The scikit-learn data ``Bunch``. + dataset: The object that we can set attributes (such as ``num_classes``) on. + + Returns: + A sequence of samples / sample metadata. + + Source + .. literalinclude:: ../../flash/template/data.py + :language: python + :lines: 64, 84-92 + """ + dataset.num_classes = len(data.target_names) + + # To share metadata between the internal components (``Preprocess``, ``Postprocess``, ``Serializer``, etc.) we + # can use ``self.set_state``. For classification tasks, where we know the labels we should set the + # ``LabelsState``. This enables the ``Labels`` serializer without it needing to be told the labels by the user. 
+ self.set_state(LabelsState(data.target_names)) + + # Now we just call super with the data and targets + return super().load_data((data.data, data.target), dataset=dataset) + + +class TemplatePreprocess(Preprocess): + """The next thing for us to implement is the :class:`~flash.data.process.Preprocess`. The + :class:`~flash.data.process.Preprocess` must take ``train_transform``, ``val_transform``, ``test_transform``, and + ``predict_transform`` arguments in the ``__init__``. Any additional arguments for it to take are up to you. + + Inside the ``__init__``, we make a call to super. This is where we register our data sources. Data sources should be + given as a dictionary which maps data source name to data source object. The name can be anything, but if you want + to take advantage of our built-in ``from_*`` classmethods, you should use + :class:`~flash.data.data_source.DefaultDataSources` as the names. In our case, we have both a + :attr:`~flash.data.data_source.DefaultDataSources.NUMPY` and a custom scikit-learn data source (which we'll call + "sklearn"). + + We can also provide a ``default_data_source``. This is the name of the data source to use by default when + predicting. It'd be cool if we could get predictions just from a numpy array, so let's use the + :attr:`~flash.data.data_source.DefaultDataSources.NUMPY` as the default. + + Args: + train_transform: The user-specified transforms to apply during training. + val_transform: The user-specified transforms to apply during validation. + test_transform: The user-specified transforms to apply during testing. + predict_transform: The user-specified transforms to apply during prediction. + + Source + .. literalinclude:: ../../flash/template/data.py + :language: python + :lines: 123-140 + """ + + def __init__( + self, + train_transform: Optional[Dict[str, Callable]] = None, + val_transform: Optional[Dict[str, Callable]] = None, + test_transform: Optional[Dict[str, Callable]] = None, + predict_transform: Optional[Dict[str, Callable]] = None, + ): + super().__init__( + train_transform=train_transform, + val_transform=val_transform, + test_transform=test_transform, + predict_transform=predict_transform, + data_sources={ + DefaultDataSources.NUMPY: TemplateNumpyDataSource(), + "sklearn": TemplateSKLearnDataSource(), + }, + default_data_source=DefaultDataSources.NUMPY, + ) + + def get_state_dict(self) -> Dict[str, Any]: + """For serialization, you have control over what to save with the ``get_state_dict`` method. It's usually a good + idea to save the transforms. So we just return them here. If you had any other attributes you wanted to save, + this is where you would return them. + + Source + .. literalinclude:: ../../flash/template/data.py + :language: python + :lines: 142, 152 + """ + return self.transforms + + @classmethod + def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool = False): + """This methods gets whatever we returned from ``get_state_dict`` as an input. Now we re-create the class with + the transforms we saved. + + Source + .. literalinclude:: ../../flash/template/data.py + :language: python + :lines: 154, 155, 164 + """ + return cls(**state_dict) + + def default_transforms(self) -> Optional[Dict[str, Callable]]: + """Your :class:`~flash.data.process.Preprocess` should usually define some default transforms. Generally, we at + least want to convert to a tensor, so let's do that here. 
+ + Our inputs samples will be dictionaries whose keys are from the + :class:`~flash.data.data_source.DefaultDataKeys`, so we need to map each key to different transforms using + :class:`~flash.data.transforms.ApplyToKeys`. By convention, we apply sequences of transforms by wrapping them in + an ``nn.Sequential``. + + Returns: + Our dictionary of transforms. + + Source + .. literalinclude:: ../../flash/template/data.py + :language: python + :lines: 166, 183-188 + """ + return { + "to_tensor_transform": nn.Sequential( + ApplyToKeys(DefaultDataKeys.INPUT, torch.from_numpy), + ApplyToKeys(DefaultDataKeys.TARGET, torch.as_tensor), + ), + } + + # If we wanted to apply different transforms at a particular stage (e.g. during training), we can prepend: `train`, + # `val`, `test`, or `predict`, and provide some different defaults linke this: + # def train_default_transforms(self) -> Optional[Dict[str, Callable]]: + + +class TemplateData(DataModule): + """Creating our :class:`~flash.data.data_module.DataModule` is as easy as setting the ``preprocess_cls`` attribute. + We'll also add the ``num_features`` property for convenience.""" + + preprocess_cls = TemplatePreprocess + + @property + def num_features(self) -> Optional[int]: + """Tries to get the ``num_features`` from each dataset in turn and returns the output.""" + return ( + getattr(self.train_dataset, "num_features", None) or getattr(self.val_dataset, "num_features", None) + or getattr(self.test_dataset, "num_features", None) + ) + + # OPTIONAL - Everything from this point onwards is an optional extra + + @staticmethod + def configure_data_fetcher(*args, **kwargs) -> BaseDataFetcher: + """We can also optionally provide a data visualization callback using the ``configure_data_fetcher`` method.""" + return TemplateVisualization(*args, **kwargs) + + +class TemplateVisualization(BaseVisualization): + """The ``TemplateVisualization`` class is a :class:`~flash.data.callbacks.BaseVisualization` that just prints the + data. If you want to provide a visualization with your task, you can override these hooks.""" + + def show_load_sample(self, samples: List[Any], running_stage: RunningStage): + print(samples) + + def show_pre_tensor_transform(self, samples: List[Any], running_stage: RunningStage): + print(samples) + + def show_to_tensor_transform(self, samples: List[Any], running_stage: RunningStage): + print(samples) + + def show_post_tensor_transform(self, samples: List[Any], running_stage: RunningStage): + print(samples) + + def show_per_batch_transform(self, batch: List[Any], running_stage): + print(batch) diff --git a/flash/template/model.py b/flash/template/model.py new file mode 100644 index 0000000000..36fb1808fa --- /dev/null +++ b/flash/template/model.py @@ -0,0 +1,130 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from types import FunctionType +from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Tuple, Type, Union + +import torch +import torchmetrics +from torch import nn +from torch.optim.lr_scheduler import _LRScheduler + +from flash.core.classification import ClassificationTask +from flash.core.registry import FlashRegistry +from flash.data.data_source import DefaultDataKeys +from flash.data.process import Serializer +from flash.vision.backbones import IMAGE_CLASSIFIER_BACKBONES + + +class ImageClassifier(ClassificationTask): + """Task that classifies images. + + Use a built in backbone + + Example:: + + from flash.vision import ImageClassifier + + classifier = ImageClassifier(backbone='resnet18') + + Or your own backbone (num_features is the number of features produced by your backbone) + + Example:: + + from flash.vision import ImageClassifier + from torch import nn + + # use any backbone + some_backbone = nn.Conv2D(...) + num_out_features = 1024 + classifier = ImageClassifier(backbone=(some_backbone, num_out_features)) + + + Args: + num_classes: Number of classes to classify. + backbone: A string or (model, num_features) tuple to use to compute image features, defaults to ``"resnet18"``. + pretrained: Use a pretrained backbone, defaults to ``True``. + loss_fn: Loss function for training, defaults to :func:`torch.nn.functional.cross_entropy`. + optimizer: Optimizer to use for training, defaults to :class:`torch.optim.SGD`. + metrics: Metrics to compute for training and evaluation, defaults to :class:`torchmetrics.Accuracy`. + learning_rate: Learning rate to use for training, defaults to ``1e-3``. + multi_label: Whether the targets are multi-label or not. + serializer: The :class:`~flash.data.process.Serializer` to use when serializing prediction outputs. 
+ """ + + backbones: FlashRegistry = IMAGE_CLASSIFIER_BACKBONES + + def __init__( + self, + num_classes: int, + backbone: Union[str, Tuple[nn.Module, int]] = "resnet18", + backbone_kwargs: Optional[Dict] = None, + head: Optional[Union[FunctionType, nn.Module]] = None, + pretrained: bool = True, + loss_fn: Optional[Callable] = None, + optimizer: Union[Type[torch.optim.Optimizer], torch.optim.Optimizer] = torch.optim.Adam, + optimizer_kwargs: Optional[Dict[str, Any]] = None, + scheduler: Optional[Union[Type[_LRScheduler], str, _LRScheduler]] = None, + scheduler_kwargs: Optional[Dict[str, Any]] = None, + metrics: Union[torchmetrics.Metric, Mapping, Sequence, None] = None, + learning_rate: float = 1e-3, + multi_label: bool = False, + serializer: Optional[Union[Serializer, Mapping[str, Serializer]]] = None, + ): + super().__init__( + model=None, + loss_fn=loss_fn, + optimizer=optimizer, + optimizer_kwargs=optimizer_kwargs, + scheduler=scheduler, + scheduler_kwargs=scheduler_kwargs, + metrics=metrics, + learning_rate=learning_rate, + multi_label=multi_label, + serializer=serializer, + ) + + self.save_hyperparameters() + + if not backbone_kwargs: + backbone_kwargs = {} + + if isinstance(backbone, tuple): + self.backbone, num_features = backbone + else: + self.backbone, num_features = self.backbones.get(backbone)(pretrained=pretrained, **backbone_kwargs) + + head = head(num_features, num_classes) if isinstance(head, FunctionType) else head + self.head = head or nn.Sequential(nn.Linear(num_features, num_classes), ) + + def training_step(self, batch: Any, batch_idx: int) -> Any: + batch = (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET]) + return super().training_step(batch, batch_idx) + + def validation_step(self, batch: Any, batch_idx: int) -> Any: + batch = (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET]) + return super().validation_step(batch, batch_idx) + + def test_step(self, batch: Any, batch_idx: int) -> Any: + batch = (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET]) + return super().test_step(batch, batch_idx) + + def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: + batch = (batch[DefaultDataKeys.INPUT]) + return super().predict_step(batch, batch_idx, dataloader_idx=dataloader_idx) + + def forward(self, x) -> torch.Tensor: + x = self.backbone(x) + if x.dim() == 4: + x = x.mean(-1).mean(-1) + return self.head(x) From 4cccd797fe44140f4da71a86e1aa3d229bf07a25 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Fri, 14 May 2021 17:39:57 +0100 Subject: [PATCH 02/53] Updates --- docs/source/task_template.rst | 112 ++++++++++++++++++++++++++++++---- flash/template/data.py | 106 ++++++++++++++++++++++---------- 2 files changed, 174 insertions(+), 44 deletions(-) diff --git a/docs/source/task_template.rst b/docs/source/task_template.rst index 01f46ebab0..7ef3f9c19d 100644 --- a/docs/source/task_template.rst +++ b/docs/source/task_template.rst @@ -4,6 +4,10 @@ Task Template This template is designed to guide you through implementing your own task in flash. You should copy the files in ``flash/template`` and adapt them to your custom task. +.. contents:: Contents: + :local: + :depth: 3 + Required Files -------------- @@ -26,27 +30,37 @@ At a minimum you will require one :class:`~flash.data.data_source.DataSource` im Take a look at our ``TemplateNumpyDataSource`` to get started: +.. autoclass:: flash.template.data.TemplateNumpyDataSource + :members: + .. raw:: html
- Click to expand + Source -.. autoclass:: flash.template.data.TemplateNumpyDataSource - :members: +.. literalinclude:: ../../flash/template/data.py + :language: python + :pyobject: TemplateNumpyDataSource .. raw:: html
+| + And have a look at our ``TemplateSKLearnDataSource`` for another example: +.. autoclass:: flash.template.data.TemplateSKLearnDataSource + :members: + .. raw:: html
- Click to expand + Source -.. autoclass:: flash.template.data.TemplateSKLearnDataSource - :members: +.. literalinclude:: ../../flash/template/data.py + :language: python + :pyobject: TemplateSKLearnDataSource .. raw:: html @@ -55,21 +69,95 @@ And have a look at our ``TemplateSKLearnDataSource`` for another example: ``Preprocess`` ^^^^^^^^^^^^^^ -The ``Preprocess`` is how all transforms are defined in Flash. -Internally we inject the ``Preprocess`` transforms into the right places so that we can address the batch at several points along the pipeline. -Defining the standard transforms (typically at least a ``to_tensor_transform`` should be defined) for your ``Preprocess`` is as simple as implementing the ``default_transforms`` method. -The ``Preprocess`` also knows about the available `DataSource` classes that it can work with, which should be configured in the ``__init__``. +The :class:`~flash.data.process.Preprocess` is how all transforms are defined in Flash. +Internally we inject the :class:`~flash.data.process.Preprocess` transforms into the right places so that we can address the batch at several points along the pipeline. +Defining the standard transforms (typically at least a ``to_tensor_transform`` should be defined) for your :class:`~flash.data.process.Preprocess` is as simple as implementing the ``default_transforms`` method. +The :class:`~flash.data.process.Preprocess` also knows about the available :class:`~flash.data.data_source.DataSource` classes that it can work with, which should be configured in the ``__init__``. Take a look at our ``TemplatePreprocess`` to get started: +.. autoclass:: flash.template.data.TemplatePreprocess + :members: + .. raw:: html
- Click to expand + Source -.. autoclass:: flash.template.data.TemplatePreprocess +.. literalinclude:: ../../flash/template/data.py + :language: python + :pyobject: TemplatePreprocess + +.. raw:: html + +
+ +``DataModule`` +^^^^^^^^^^^^^^ + +The :class:`~flash.data.data_module.DataModule` is where the hard work of our :class:`~flash.data.data_source.DataSource` and :class:`~flash.data.process.Preprocess` implementations pays off. +If your :class:`~flash.data.data_source.DataSource` implementation(s) conform to our :class:`~flash.data.data_source.DefaultDataSources` (e.g. ``DefaultDataSources.FOLDERS``) then your :class:`~flash.data.data_module.DataModule` implementation simply needs a ``preprocess_cls`` attribute. +You now have a :class:`~flash.data.data_module.DataModule` that can be instantiated with ``from_*`` for whichever data sources you have configured (e.g. ``MyDataModule.from_folders``). +It also includes all of your default transforms! + +If you've defined a fully custom :class:`~flash.data.data_source.DataSource` (like our ``TemplateSKLearnDataSource``), then you will need a ``from_*`` method for each (we'll define ``from_sklearn`` for our example). +The ``from_*`` methods take whatever arguments you want them too and call ``super().from_data_source`` with the name given to your custom data source in the ``Preprocess.__init__``. + + +Take a look at our ``TemplateData`` to get started: + +.. autoclass:: flash.template.data.TemplateData :members: +.. raw:: html + +
+ Source + +.. literalinclude:: ../../flash/template/data.py + :language: python + :pyobject: TemplateData + .. raw:: html
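
Putting it together, here's a hedged usage sketch of the ``from_sklearn`` method we define for our example (the split and batch size are arbitrary choices):

.. code-block:: python

    from sklearn import datasets

    from flash.template.data import TemplateData

    datamodule = TemplateData.from_sklearn(
        train_bunch=datasets.load_iris(),
        val_split=0.25,  # hold out 25% of the training data for validation
        batch_size=8,
    )

    print(datamodule.num_classes)  # 3, recorded by our ``load_data``
    print(datamodule.num_features)  # 4, via the ``num_features`` property
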
+ +``BaseVisualization`` +^^^^^^^^^^^^^^^^^^^^^ + +An optional step is to implement a ``BaseVisualization``. The ``BaseVisualization`` lets you control how data at various points in the pipeline can be visualized. +This is extremely useful for debugging purposes, allowing users to view their data and understand the impact of their transforms. + +Take a look at our ``TemplateVisualization`` to get started: + +.. note:: + Don't worry about implementing it right away, you can always come back and add it later! + +.. autoclass:: flash.template.data.TemplateVisualization + :members: + +.. raw:: html + +
+ Source + +.. literalinclude:: ../../flash/template/data.py + :language: python + :pyobject: TemplateVisualization + +.. raw:: html + +
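
If the ``configure_data_fetcher`` hook is implemented as above, the visualization can be triggered from the ``DataModule``.
A rough sketch, assuming the ``show_train_batch`` helper behaves as it does for our other tasks:

.. code-block:: python

    from sklearn import datasets

    from flash.template.data import TemplateData

    datamodule = TemplateData.from_sklearn(train_bunch=datasets.load_iris(), val_split=0.25)

    # Each hook in ``TemplateVisualization`` simply prints the samples it is given.
    datamodule.show_train_batch("load_sample")
    datamodule.show_train_batch("to_tensor_transform")
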
+ +``Postprocess`` +^^^^^^^^^^^^^^^ + +Sometimes you have some transforms that need to be applied _after_ your model. +For this you can optionally implement a :class:`~flash.data.process.Postprocess`. +The :class:`~flash.data.process.Postprocess` is applied to the model outputs during inference. +You may want to use it for: converting tokens back into text, applying an inverse normalization to an output image, resizing a generated image back to the size of the input, etc. + +`model.py` +~~~~~~~~~~ + +Inside `model.py` you just need to implement your `Task`. diff --git a/flash/template/data.py b/flash/template/data.py index 599910878a..7f7a318b2a 100644 --- a/flash/template/data.py +++ b/flash/template/data.py @@ -14,6 +14,7 @@ from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple import numpy as np +import sklearn.utils import torch from pytorch_lightning.trainer.states import RunningStage from sklearn.utils import Bunch @@ -46,11 +47,6 @@ def load_data(self, data: Tuple[np.ndarray, Sequence[Any]], dataset: Any) -> Seq Returns: A sequence of samples / sample metadata. - - Source - .. literalinclude:: ../../flash/template/data.py - :language: python - :lines: 34, 55-58 """ dataset.num_features = data[0].shape[1] @@ -75,11 +71,6 @@ def load_data(self, data: Bunch, dataset: Any) -> Sequence[Mapping[str, Any]]: Returns: A sequence of samples / sample metadata. - - Source - .. literalinclude:: ../../flash/template/data.py - :language: python - :lines: 64, 84-92 """ dataset.num_classes = len(data.target_names) @@ -113,11 +104,6 @@ class TemplatePreprocess(Preprocess): val_transform: The user-specified transforms to apply during validation. test_transform: The user-specified transforms to apply during testing. predict_transform: The user-specified transforms to apply during prediction. - - Source - .. literalinclude:: ../../flash/template/data.py - :language: python - :lines: 123-140 """ def __init__( @@ -143,11 +129,6 @@ def get_state_dict(self) -> Dict[str, Any]: """For serialization, you have control over what to save with the ``get_state_dict`` method. It's usually a good idea to save the transforms. So we just return them here. If you had any other attributes you wanted to save, this is where you would return them. - - Source - .. literalinclude:: ../../flash/template/data.py - :language: python - :lines: 142, 152 """ return self.transforms @@ -155,11 +136,6 @@ def get_state_dict(self) -> Dict[str, Any]: def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool = False): """This methods gets whatever we returned from ``get_state_dict`` as an input. Now we re-create the class with the transforms we saved. - - Source - .. literalinclude:: ../../flash/template/data.py - :language: python - :lines: 154, 155, 164 """ return cls(**state_dict) @@ -174,11 +150,6 @@ def default_transforms(self) -> Optional[Dict[str, Callable]]: Returns: Our dictionary of transforms. - - Source - .. literalinclude:: ../../flash/template/data.py - :language: python - :lines: 166, 183-188 """ return { "to_tensor_transform": nn.Sequential( @@ -194,10 +165,81 @@ def default_transforms(self) -> Optional[Dict[str, Callable]]: class TemplateData(DataModule): """Creating our :class:`~flash.data.data_module.DataModule` is as easy as setting the ``preprocess_cls`` attribute. - We'll also add the ``num_features`` property for convenience.""" + We get the ``from_numpy`` method for free as we've configured a ``DefaultDataSources.NUMPY`` data source. 
We'll also
+    add a ``from_sklearn`` method so that we can use our ``TemplateSKLearnDataSource``. Finally, we define the
+    ``num_features`` property for convenience.
+    """
 
     preprocess_cls = TemplatePreprocess
 
+    @classmethod
+    def from_sklearn(
+        cls,
+        train_bunch: Optional[Bunch] = None,
+        val_bunch: Optional[Bunch] = None,
+        test_bunch: Optional[Bunch] = None,
+        predict_bunch: Optional[Bunch] = None,
+        train_transform: Optional[Dict[str, Callable]] = None,
+        val_transform: Optional[Dict[str, Callable]] = None,
+        test_transform: Optional[Dict[str, Callable]] = None,
+        predict_transform: Optional[Dict[str, Callable]] = None,
+        data_fetcher: Optional[BaseDataFetcher] = None,
+        preprocess: Optional[Preprocess] = None,
+        val_split: Optional[float] = None,
+        batch_size: int = 4,
+        num_workers: Optional[int] = None,
+        **preprocess_kwargs: Any,
+    ):
+        """This is our custom ``from_*`` method. It expects scikit-learn ``Bunch`` objects as input and passes them
+        through to the :meth:`~flash.data.data_module.DataModule.from_data_source` method underneath. It's really just
+        a convenience method to save the user from needing to call
+        :meth:`~flash.data.data_module.DataModule.from_data_source` directly.
+
+        Args:
+            train_bunch: The scikit-learn ``Bunch`` containing the train data.
+            val_bunch: The scikit-learn ``Bunch`` containing the validation data.
+            test_bunch: The scikit-learn ``Bunch`` containing the test data.
+            predict_bunch: The scikit-learn ``Bunch`` containing the predict data.
+            train_transform: The dictionary of transforms to use during training which maps
+                :class:`~flash.data.process.Preprocess` hook names to callable transforms.
+            val_transform: The dictionary of transforms to use during validation which maps
+                :class:`~flash.data.process.Preprocess` hook names to callable transforms.
+            test_transform: The dictionary of transforms to use during testing which maps
+                :class:`~flash.data.process.Preprocess` hook names to callable transforms.
+            predict_transform: The dictionary of transforms to use during predicting which maps
+                :class:`~flash.data.process.Preprocess` hook names to callable transforms.
+            data_fetcher: The :class:`~flash.data.callback.BaseDataFetcher` to pass to the
+                :class:`~flash.data.data_module.DataModule`.
+            preprocess: The :class:`~flash.data.process.Preprocess` to pass to the
+                :class:`~flash.data.data_module.DataModule`. If ``None``, ``cls.preprocess_cls`` will be constructed
+                and used.
+            val_split: The ``val_split`` argument to pass to the :class:`~flash.data.data_module.DataModule`.
+            batch_size: The ``batch_size`` argument to pass to the :class:`~flash.data.data_module.DataModule`.
+            num_workers: The ``num_workers`` argument to pass to the :class:`~flash.data.data_module.DataModule`.
+            preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used
+                if ``preprocess = None``.
+
+        Returns:
+            The constructed data module.
+ """ + return super().from_data_source( + "sklearn", + train_bunch, + val_bunch, + test_bunch, + predict_bunch, + train_transform=train_transform, + val_transform=val_transform, + test_transform=test_transform, + predict_transform=predict_transform, + data_fetcher=data_fetcher, + preprocess=preprocess, + val_split=val_split, + batch_size=batch_size, + num_workers=num_workers, + **preprocess_kwargs, + ) + @property def num_features(self) -> Optional[int]: """Tries to get the ``num_features`` from each dataset in turn and returns the output.""" @@ -210,7 +252,7 @@ def num_features(self) -> Optional[int]: @staticmethod def configure_data_fetcher(*args, **kwargs) -> BaseDataFetcher: - """We can also optionally provide a data visualization callback using the ``configure_data_fetcher`` method.""" + """We can, *optionally*, provide a data visualization callback using the ``configure_data_fetcher`` method.""" return TemplateVisualization(*args, **kwargs) From 838e40cd862b1d8074d35bf69069c525f315d1a9 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Fri, 14 May 2021 21:10:31 +0100 Subject: [PATCH 03/53] Updates --- docs/source/index.rst | 13 ++++- .../{task_template.rst => template/data.rst} | 56 ++++++++----------- docs/source/template/examples.rst | 16 ++++++ docs/source/template/intro.rst | 26 +++++++++ docs/source/template/model.rst | 15 +++++ docs/source/template/optional.rst | 30 ++++++++++ docs/source/template/tests.rst | 5 ++ 7 files changed, 126 insertions(+), 35 deletions(-) rename docs/source/{task_template.rst => template/data.rst} (84%) create mode 100644 docs/source/template/examples.rst create mode 100644 docs/source/template/intro.rst create mode 100644 docs/source/template/model.rst create mode 100644 docs/source/template/optional.rst create mode 100644 docs/source/template/tests.rst diff --git a/docs/source/index.rst b/docs/source/index.rst index 836fb7003e..1ce28229d6 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -13,7 +13,6 @@ Lightning Flash quickstart installation custom_task - task_template reference/flash_to_pl .. toctree:: @@ -50,6 +49,18 @@ Lightning Flash general/finetuning general/predictions + +.. toctree:: + :maxdepth: 1 + :caption: Contributing a Task + + template/intro + template/data + template/model + template/optional + template/examples + template/tests + Indices and tables ================== diff --git a/docs/source/task_template.rst b/docs/source/template/data.rst similarity index 84% rename from docs/source/task_template.rst rename to docs/source/template/data.rst index 7ef3f9c19d..f10ed7b119 100644 --- a/docs/source/task_template.rst +++ b/docs/source/template/data.rst @@ -1,20 +1,11 @@ -Task Template -============= +.. _contributing_data: -This template is designed to guide you through implementing your own task in flash. -You should copy the files in ``flash/template`` and adapt them to your custom task. +******** +The Data +******** -.. contents:: Contents: - :local: - :depth: 3 - -Required Files --------------- - -``data.py`` -~~~~~~~~~~~ - -Inside ``data.py`` you should implement: +The first step to contributing a task is to implement the classes we need to load some data. +Inside ``data.py`` you we implement: #. one or more :class:`~flash.data.data_source.DataSource` classes #. a :class:`~flash.data.process.Preprocess` @@ -22,8 +13,8 @@ Inside ``data.py`` you should implement: #. a :class:`~flash.data.callbacks.BaseVisualization` *(optional)* #. 
a :class:`~flash.data.process.Postprocess` *(optional)* -``DataSource`` -^^^^^^^^^^^^^^ +DataSource +^^^^^^^^^^ The :class:`~flash.data.data_source.DataSource` implementations describe how data from particular sources (like folders, files, tensors, etc.) should be loaded. At a minimum you will require one :class:`~flash.data.data_source.DataSource` implementation, but if you want to support a few different ways of loading data for your task, the more the merrier! @@ -38,7 +29,7 @@ Take a look at our ``TemplateNumpyDataSource`` to get started:
Source -.. literalinclude:: ../../flash/template/data.py +.. literalinclude:: ../../../flash/template/data.py :language: python :pyobject: TemplateNumpyDataSource @@ -58,7 +49,7 @@ And have a look at our ``TemplateSKLearnDataSource`` for another example:
Source -.. literalinclude:: ../../flash/template/data.py +.. literalinclude:: ../../../flash/template/data.py :language: python :pyobject: TemplateSKLearnDataSource @@ -66,8 +57,8 @@ And have a look at our ``TemplateSKLearnDataSource`` for another example:
-``Preprocess`` -^^^^^^^^^^^^^^ +Preprocess +^^^^^^^^^^ The :class:`~flash.data.process.Preprocess` is how all transforms are defined in Flash. Internally we inject the :class:`~flash.data.process.Preprocess` transforms into the right places so that we can address the batch at several points along the pipeline. @@ -84,7 +75,7 @@ Take a look at our ``TemplatePreprocess`` to get started:
Source -.. literalinclude:: ../../flash/template/data.py +.. literalinclude:: ../../../flash/template/data.py :language: python :pyobject: TemplatePreprocess @@ -92,8 +83,8 @@ Take a look at our ``TemplatePreprocess`` to get started:
-``DataModule`` -^^^^^^^^^^^^^^ +DataModule +^^^^^^^^^^ The :class:`~flash.data.data_module.DataModule` is where the hard work of our :class:`~flash.data.data_source.DataSource` and :class:`~flash.data.process.Preprocess` implementations pays off. If your :class:`~flash.data.data_source.DataSource` implementation(s) conform to our :class:`~flash.data.data_source.DefaultDataSources` (e.g. ``DefaultDataSources.FOLDERS``) then your :class:`~flash.data.data_module.DataModule` implementation simply needs a ``preprocess_cls`` attribute. @@ -114,7 +105,7 @@ Take a look at our ``TemplateData`` to get started:
Source -.. literalinclude:: ../../flash/template/data.py +.. literalinclude:: ../../../flash/template/data.py :language: python :pyobject: TemplateData @@ -122,8 +113,8 @@ Take a look at our ``TemplateData`` to get started:
-``BaseVisualization`` -^^^^^^^^^^^^^^^^^^^^^ +BaseVisualization +^^^^^^^^^^^^^^^^^ An optional step is to implement a ``BaseVisualization``. The ``BaseVisualization`` lets you control how data at various points in the pipeline can be visualized. This is extremely useful for debugging purposes, allowing users to view their data and understand the impact of their transforms. @@ -141,7 +132,7 @@ Take a look at our ``TemplateVisualization`` to get started:
Source -.. literalinclude:: ../../flash/template/data.py +.. literalinclude:: ../../../flash/template/data.py :language: python :pyobject: TemplateVisualization @@ -149,15 +140,12 @@ Take a look at our ``TemplateVisualization`` to get started:
-``Postprocess`` -^^^^^^^^^^^^^^^ +Postprocess +^^^^^^^^^^^ Sometimes you have some transforms that need to be applied _after_ your model. For this you can optionally implement a :class:`~flash.data.process.Postprocess`. The :class:`~flash.data.process.Postprocess` is applied to the model outputs during inference. You may want to use it for: converting tokens back into text, applying an inverse normalization to an output image, resizing a generated image back to the size of the input, etc. -`model.py` -~~~~~~~~~~ - -Inside `model.py` you just need to implement your `Task`. +:ref:`Now that you've got some data, it's time to implement your task! ` diff --git a/docs/source/template/examples.rst b/docs/source/template/examples.rst new file mode 100644 index 0000000000..93631cfc13 --- /dev/null +++ b/docs/source/template/examples.rst @@ -0,0 +1,16 @@ +.. _contributing_examples: + +************ +The Examples +************ + +Now you've implemented your task, it's time to add some examples showing how cool it is! +We usually provide one finetuning example (in `flash_examples/finetuning`) and one predict / inference example (in `flash_examples/predict`). +You can base these off of our `template.py` examples. +Let's take a closer look. + +finetuning +========== + +predict +======= diff --git a/docs/source/template/intro.rst b/docs/source/template/intro.rst new file mode 100644 index 0000000000..99b68058c2 --- /dev/null +++ b/docs/source/template/intro.rst @@ -0,0 +1,26 @@ +.. _contributing: + +************ +Introduction +************ + +Welcome +======= + +Before you begin, we'd like to express our sincere gratitude to you for wanting to add a task to Flash. +With Flash our aim is to create a great user experience, enabling awesome advanced applications with just a few lines of code. +We're really pleased with what we've achieved with Flash and we hope you will be too. +Now let's dive in! + +Tutorials +========= + +The Task template is designed to guide you through contributing a task to Flash. +You should copy the files in ``flash/template`` to get started. +The tutorials in this section will take you through all of the components you need to implement for your custom task. + +- :ref:`contributing_data`: our first tutorial goes over the best practices for implementing everything you need to connect data to your task +- :ref:`contributing_task`: now that we have the data, in this tutorial we create our custom task +- :ref:`contributing_optional`: this tutorial covers some optional extras you can add if needed for your particular task +- :ref:`contributing_examples`: this tutorial guides you through creating some simple examples showing your task in action +- :ref:`contributing_tests`: in our final tutorial, we cover best practices for writing some tests for your new task diff --git a/docs/source/template/model.rst b/docs/source/template/model.rst new file mode 100644 index 0000000000..d71ac1941c --- /dev/null +++ b/docs/source/template/model.rst @@ -0,0 +1,15 @@ +.. _contributing_task: + +******** +The Task +******** + +Inside ``model.py`` you just need to implement your ``Task``. +The `Task` is responsible for the forward pass of the model. +It's just a `LightningModule` with some helpful defaults, so anything you can do inside a `LightningModule` you can do inside a `Task`. + +Task +^^^^ + +You should configure a default loss function and optimizer and some default metrics and models in your `Task`. +Take a look at our `TemplateTask` to get started. 
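
If you'd like a feel for the shape of a `Task` before diving in, here's a minimal hedged sketch.
The model, loss function, and metric choices are illustrative assumptions, not the template's actual code:

.. code-block:: python

    import torch
    import torchmetrics
    from torch import nn
    from torch.nn import functional as F

    from flash.core.model import Task


    class MyClassificationTask(Task):
        """Illustrative task: a linear model with sensible defaults."""

        def __init__(self, num_features: int, num_classes: int, learning_rate: float = 1e-3):
            super().__init__(
                model=nn.Linear(num_features, num_classes),
                loss_fn=F.cross_entropy,
                optimizer=torch.optim.Adam,
                metrics=torchmetrics.Accuracy(),
                learning_rate=learning_rate,
            )
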
diff --git a/docs/source/template/optional.rst b/docs/source/template/optional.rst new file mode 100644 index 0000000000..0593b42c90 --- /dev/null +++ b/docs/source/template/optional.rst @@ -0,0 +1,30 @@ +.. _contributing_optional: + +*************** +Optional Extras +*************** + +transforms.py +============= + +Sometimes you'd like to have quite a few transforms by default (standard augmentations, normalization, etc.). +If you do then, for better organization, you can define a `transforms.py` which houses your default transforms to be referenced in your `Preprocess`. +Take a look at `vision/classification/transforms.py` for an example. + +backbones.py +============ + +In Flash, we love to provide as much access to the state-of-the-art as we can. +To this end, we've created the `FlashRegistry` and the backbones API. +These allow you to register backbones for your task that can be selected by the user. +The backbones can come from anywhere as long as you can register a function that loads the backbone. +If you want to configure some backbones for your task, it's best practice to include these in a `backbones.py` file. +Take a look at `vision/backbones.py` for an example, and have a look at `vision/classification/model.py` to see how these can be added to your `Task`. + +serialization.py +================ + +Sometimes you want to give the user some control over their prediction format. +`Postprocess` can do the heavy lifting (anything you always want to apply to the predictions), but one or more custom `Serializer` implementations can be used to convert the predictions to a desired output format. +A good example is in classification; sometimes we'd like the classes, sometimes the logits, sometimes the labels, you get the idea. +You should add your `Serializer` implementations in a `serialization.py` file and set a good default in your `Task`. diff --git a/docs/source/template/tests.rst b/docs/source/template/tests.rst new file mode 100644 index 0000000000..5652499546 --- /dev/null +++ b/docs/source/template/tests.rst @@ -0,0 +1,5 @@ +.. _contributing_tests: + +********* +The Tests +********* From 77353493b1abcd44a7a4e19169b76a9719293aab Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Mon, 17 May 2021 12:08:09 +0100 Subject: [PATCH 04/53] Updates --- docs/source/template/data.rst | 27 +++++++-- docs/source/template/intro.rst | 4 +- docs/source/template/model.rst | 35 +++++++++-- flash/template/data.py | 13 ++--- flash/template/model.py | 103 +++++++++++++-------------------- 5 files changed, 102 insertions(+), 80 deletions(-) diff --git a/docs/source/template/data.rst b/docs/source/template/data.rst index f10ed7b119..3c4b19bf0d 100644 --- a/docs/source/template/data.rst +++ b/docs/source/template/data.rst @@ -5,7 +5,7 @@ The Data ******** The first step to contributing a task is to implement the classes we need to load some data. -Inside ``data.py`` you we implement: +Inside ``data.py`` you should implement: #. one or more :class:`~flash.data.data_source.DataSource` classes #. a :class:`~flash.data.process.Preprocess` @@ -18,8 +18,13 @@ DataSource The :class:`~flash.data.data_source.DataSource` implementations describe how data from particular sources (like folders, files, tensors, etc.) should be loaded. At a minimum you will require one :class:`~flash.data.data_source.DataSource` implementation, but if you want to support a few different ways of loading data for your task, the more the merrier! 
+Each :class:`~flash.data.data_source.DataSource` has a ``load_data`` method and a ``load_sample`` method.
+The ``load_data`` method accepts some dataset metadata (e.g. a folder name) and produces a sequence or iterable of samples or sample metadata.
+The ``load_sample`` method then takes as input a single element from the output of ``load_data`` and returns a sample.
+By default, these methods just return their input, as you will not always need both to create a functioning :class:`~flash.data.data_source.DataSource`.
 
-Take a look at our ``TemplateNumpyDataSource`` to get started:
+It's best practice to just override one of our existing :class:`~flash.data.data_source.DataSource` classes where possible.
+Take a look at our ``TemplateNumpyDataSource`` which does this to get started:
 
 .. autoclass:: flash.template.data.TemplateNumpyDataSource
    :members:
@@ -39,7 +44,7 @@ Take a look at our ``TemplateNumpyDataSource`` to get started:
 
 |
 
-And have a look at our ``TemplateSKLearnDataSource`` for another example:
+Sometimes you need to do something a bit more custom; have a look at our ``TemplateSKLearnDataSource`` for an example:
 
 .. autoclass:: flash.template.data.TemplateSKLearnDataSource
    :members:
@@ -57,7 +62,16 @@ And have a look at our ``TemplateSKLearnDataSource`` for another example:
 
+DataSource vs Dataset
+~~~~~~~~~~~~~~~~~~~~~
+
+A :class:`~flash.data.data_source.DataSource` is not the same as a :class:`torch.utils.data.Dataset`.
+A :class:`torch.utils.data.Dataset` knows about the data, whereas a :class:`~flash.data.data_source.DataSource` only knows how to load the data.
+So it's possible for a single :class:`~flash.data.data_source.DataSource` to create more than one :class:`~torch.utils.data.Dataset`.
+It's also fine for the output of the ``load_data`` method to just be a :class:`torch.utils.data.Dataset` instance.
+You don't need to re-write custom datasets; just instantiate them in ``load_data``, similarly to what we did with the ``TemplateSKLearnDataSource``.
+For example, the ``load_data`` of the PyTorchVideo ``PathsDataSource`` just creates a :class:`pytorchvideo.data.encoded_video_dataset.EncodedVideoDataset` from the given folder.
+
 Preprocess
 ^^^^^^^^^^
 
@@ -92,7 +107,7 @@ You now have a :class:`~flash.data.data_module.DataModule` that can be instantia
 It also includes all of your default transforms!
 
 If you've defined a fully custom :class:`~flash.data.data_source.DataSource` (like our ``TemplateSKLearnDataSource``), then you will need a ``from_*`` method for each (we'll define ``from_sklearn`` for our example).
-The ``from_*`` methods take whatever arguments you want them too and call ``super().from_data_source`` with the name given to your custom data source in the ``Preprocess.__init__``.
+The ``from_*`` methods take whatever arguments you want them to and call :meth:`~flash.data.data_module.DataModule.from_data_source` with the name given to your custom data source in the ``Preprocess.__init__``.
 
 
 Take a look at our ``TemplateData`` to get started:
 
@@ -148,4 +163,6 @@ For this you can optionally implement a :class:`~flash.data.process.Postprocess`
 The :class:`~flash.data.process.Postprocess` is applied to the model outputs during inference.
 You may want to use it for: converting tokens back into text, applying an inverse normalization to an output image, resizing a generated image back to the size of the input, etc.
 
-:ref:`Now that you've got some data, it's time to implement your task! <contributing_task>`
+------
+
+Now that you've got some data, it's time to :ref:`implement your task! <contributing_task>`
diff --git a/docs/source/template/intro.rst b/docs/source/template/intro.rst
index 99b68058c2..4dc097b3d0 100644
--- a/docs/source/template/intro.rst
+++ b/docs/source/template/intro.rst
@@ -16,8 +16,10 @@ Tutorials
 =========
 
 The Task template is designed to guide you through contributing a task to Flash.
+It contains the code, tests, and examples for a task that performs classification with a multi-layer perceptron, intended for use with the classic data sets from scikit-learn.
 You should copy the files in ``flash/template`` to get started.
-The tutorials in this section will take you through all of the components you need to implement for your custom task.
+
+The tutorials in this section will walk you through all of the components you need to implement (or adapt from the template) for your custom task.
- :ref:`contributing_data`: our first tutorial goes over the best practices for implementing everything you need to connect data to your task - :ref:`contributing_task`: now that we have the data, in this tutorial we create our custom task diff --git a/docs/source/template/model.rst b/docs/source/template/model.rst index d71ac1941c..6b0dc371e7 100644 --- a/docs/source/template/model.rst +++ b/docs/source/template/model.rst @@ -4,12 +4,37 @@ The Task ******** -Inside ``model.py`` you just need to implement your ``Task``. -The `Task` is responsible for the forward pass of the model. -It's just a `LightningModule` with some helpful defaults, so anything you can do inside a `LightningModule` you can do inside a `Task`. +Inside ``model.py`` you just need to implement your :class:`~flash.core.model.Task`. +The :class:`~flash.core.model.Task` is responsible for the forward pass of the model. +It's just a :any:`pytorch_lightning:lightning_module` with some helpful defaults, so anything you can do inside a :any:`pytorch_lightning:lightning_module` you can do inside a :class:`~flash.core.model.Task`. Task ^^^^ -You should configure a default loss function and optimizer and some default metrics and models in your `Task`. -Take a look at our `TemplateTask` to get started. +You should configure a default loss function and optimizer and some default metrics and models in your :class:`~flash.core.model.Task`. +For our scikit-learn example, we can just override :class:`~flash.core.classification.ClassificationTask` which provides these defaults for us. +You should also override the ``*_step`` methods to unpack your sample. +The default ``*_step`` implementations in :class:`~flash.core.model.Task` expect a tuple containing the input and target, and should be suitable for most applications. +In our template example, we just extract the input and target from the input mapping and forward them to the super methods. + +Here's our ``TemplateSKLearnClassifier``: + +.. autoclass:: flash.template.model.TemplateSKLearnClassifier + :members: + +.. raw:: html + +
+ Source + +.. literalinclude:: ../../../flash/template/model.py + :language: python + :pyobject: TemplateSKLearnClassifier + +.. raw:: html + +
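
To see the pieces working together, here's a hedged end-to-end sketch combining the task with the ``TemplateData`` module from the data tutorial (the hyperparameters are arbitrary):

.. code-block:: python

    from sklearn import datasets

    import flash
    from flash.template.data import TemplateData
    from flash.template.model import TemplateSKLearnClassifier

    datamodule = TemplateData.from_sklearn(train_bunch=datasets.load_iris(), val_split=0.25)

    model = TemplateSKLearnClassifier(
        num_features=datamodule.num_features,
        num_classes=datamodule.num_classes,
    )

    trainer = flash.Trainer(max_epochs=5)
    trainer.fit(model, datamodule=datamodule)
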
+ +------ + +Now that you've got your task, take a look at some :ref:`optional advanced features you can add ` or go ahead and :ref:`create some examples showing your task in action! ` diff --git a/flash/template/data.py b/flash/template/data.py index 7f7a318b2a..da5de4d678 100644 --- a/flash/template/data.py +++ b/flash/template/data.py @@ -14,18 +14,17 @@ from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple import numpy as np -import sklearn.utils import torch from pytorch_lightning.trainer.states import RunningStage from sklearn.utils import Bunch from torch import nn -from flash.data.base_viz import BaseVisualization -from flash.data.callback import BaseDataFetcher -from flash.data.data_module import DataModule -from flash.data.data_source import DefaultDataKeys, DefaultDataSources, LabelsState, NumpyDataSource -from flash.data.process import Preprocess -from flash.data.transforms import ApplyToKeys +from flash.core.data.base_viz import BaseVisualization +from flash.core.data.callback import BaseDataFetcher +from flash.core.data.data_module import DataModule +from flash.core.data.data_source import DefaultDataKeys, DefaultDataSources, LabelsState, NumpyDataSource +from flash.core.data.process import Preprocess +from flash.core.data.transforms import ApplyToKeys class TemplateNumpyDataSource(NumpyDataSource): diff --git a/flash/template/model.py b/flash/template/model.py index 36fb1808fa..d8cbbe1821 100644 --- a/flash/template/model.py +++ b/flash/template/model.py @@ -11,8 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from types import FunctionType -from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Tuple, Type, Union +from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Type, Union import torch import torchmetrics @@ -20,57 +19,37 @@ from torch.optim.lr_scheduler import _LRScheduler from flash.core.classification import ClassificationTask -from flash.core.registry import FlashRegistry -from flash.data.data_source import DefaultDataKeys -from flash.data.process import Serializer -from flash.vision.backbones import IMAGE_CLASSIFIER_BACKBONES +from flash.core.data.data_source import DefaultDataKeys +from flash.core.data.process import Serializer -class ImageClassifier(ClassificationTask): - """Task that classifies images. - - Use a built in backbone - - Example:: - - from flash.vision import ImageClassifier - - classifier = ImageClassifier(backbone='resnet18') - - Or your own backbone (num_features is the number of features produced by your backbone) - - Example:: - - from flash.vision import ImageClassifier - from torch import nn - - # use any backbone - some_backbone = nn.Conv2D(...) - num_out_features = 1024 - classifier = ImageClassifier(backbone=(some_backbone, num_out_features)) - +class TemplateSKLearnClassifier(ClassificationTask): + """The ``TemplateSKLearnClassifier`` is a :class:`~flash.core.classification.ClassificationTask` that uses a simple + multi-layer perceptron model to classify tabular data from scikit-learn. In the ``__init__``, we create our model + and pass it to the super :class:`~flash.core.model.Task` along with any arguments that we need. Args: - num_classes: Number of classes to classify. - backbone: A string or (model, num_features) tuple to use to compute image features, defaults to ``"resnet18"``. 
- pretrained: Use a pretrained backbone, defaults to ``True``. - loss_fn: Loss function for training, defaults to :func:`torch.nn.functional.cross_entropy`. - optimizer: Optimizer to use for training, defaults to :class:`torch.optim.SGD`. - metrics: Metrics to compute for training and evaluation, defaults to :class:`torchmetrics.Accuracy`. - learning_rate: Learning rate to use for training, defaults to ``1e-3``. - multi_label: Whether the targets are multi-label or not. - serializer: The :class:`~flash.data.process.Serializer` to use when serializing prediction outputs. + num_features: The number of features (elements) in the input data. + num_classes: The number of classes (outputs) for this :class:`~flash.core.model.Task`. + hidden_size: The number of units to use in the hidden layer of the multi-layer perceptron model. + loss_fn: The loss function to use. If ``None``, a default will be selected by the + :class:`~flash.core.classification.ClassificationTask` depending on the ``multi_label`` argument. + optimizer: The optimizer or optimizer class to use. + optimizer_kwargs: Additional kwargs to use when creating the optimizer (if not passed as an instance). + scheduler: The scheduler or scheduler class to use. + scheduler_kwargs: Additional kwargs to use when creating the scheduler (if not passed as an instance). + metrics: Any metrics to use with this :class:`~flash.core.model.Task`. If ``None``, a default will be selected + by the :class:`~flash.core.classification.ClassificationTask` depending on the ``multi_label`` argument. + learning_rate: The learning rate for the optimizer. + multi_label: If ``True``, this will be treated as a multi-label classification problem. + serializer: The :class:`~flash.core.data.process.Serializer` to use for prediction outputs. 
""" - backbones: FlashRegistry = IMAGE_CLASSIFIER_BACKBONES - def __init__( self, + num_features: int, num_classes: int, - backbone: Union[str, Tuple[nn.Module, int]] = "resnet18", - backbone_kwargs: Optional[Dict] = None, - head: Optional[Union[FunctionType, nn.Module]] = None, - pretrained: bool = True, + hidden_size: 128, loss_fn: Optional[Callable] = None, optimizer: Union[Type[torch.optim.Optimizer], torch.optim.Optimizer] = torch.optim.Adam, optimizer_kwargs: Optional[Dict[str, Any]] = None, @@ -81,8 +60,14 @@ def __init__( multi_label: bool = False, serializer: Optional[Union[Serializer, Mapping[str, Serializer]]] = None, ): + model = nn.Sequential( + nn.Linear(num_features, hidden_size), + nn.ReLU(True), + nn.Linear(hidden_size, num_classes), + ) + super().__init__( - model=None, + model=model, loss_fn=loss_fn, optimizer=optimizer, optimizer_kwargs=optimizer_kwargs, @@ -96,35 +81,29 @@ def __init__( self.save_hyperparameters() - if not backbone_kwargs: - backbone_kwargs = {} - - if isinstance(backbone, tuple): - self.backbone, num_features = backbone - else: - self.backbone, num_features = self.backbones.get(backbone)(pretrained=pretrained, **backbone_kwargs) - - head = head(num_features, num_classes) if isinstance(head, FunctionType) else head - self.head = head or nn.Sequential(nn.Linear(num_features, num_classes), ) - def training_step(self, batch: Any, batch_idx: int) -> Any: + """For the training step, we just extract the :attr:`~flash.core.data.data_source.DefaultDataKeys.INPUT` and + :attr:`~flash.core.data.data_source.DefaultDataKeys.TARGET` keys from the input and forward them to the + :meth:`~flash.core.model.Task.training_step`.""" batch = (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET]) return super().training_step(batch, batch_idx) def validation_step(self, batch: Any, batch_idx: int) -> Any: + """For the validation step, we just extract the :attr:`~flash.core.data.data_source.DefaultDataKeys.INPUT` and + :attr:`~flash.core.data.data_source.DefaultDataKeys.TARGET` keys from the input and forward them to the + :meth:`~flash.core.model.Task.validation_step`.""" batch = (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET]) return super().validation_step(batch, batch_idx) def test_step(self, batch: Any, batch_idx: int) -> Any: + """For the test step, we just extract the :attr:`~flash.core.data.data_source.DefaultDataKeys.INPUT` and + :attr:`~flash.core.data.data_source.DefaultDataKeys.TARGET` keys from the input and forward them to the + :meth:`~flash.core.model.Task.test_step`.""" batch = (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET]) return super().test_step(batch, batch_idx) def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: + """For the predict step, we just extract the :attr:`~flash.core.data.data_source.DefaultDataKeys.INPUT` key from + the input and forward it to the :meth:`~flash.core.model.Task.predict_step`.""" batch = (batch[DefaultDataKeys.INPUT]) return super().predict_step(batch, batch_idx, dataloader_idx=dataloader_idx) - - def forward(self, x) -> torch.Tensor: - x = self.backbone(x) - if x.dim() == 4: - x = x.mean(-1).mean(-1) - return self.head(x) From ce2108f00d792c6fa7ffff28d070dd272760834e Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Mon, 17 May 2021 12:10:46 +0100 Subject: [PATCH 05/53] Remove template README --- flash/template/README.md | 96 ---------------------------------------- 1 file changed, 96 deletions(-) delete mode 100644 flash/template/README.md diff --git 
a/flash/template/README.md b/flash/template/README.md deleted file mode 100644 index 834c376d07..0000000000 --- a/flash/template/README.md +++ /dev/null @@ -1,96 +0,0 @@ -# Lightning Flash Task Template - -This template is designed to guide you through implementing your own task in flash. -You should copy the files here and adapt them to your custom task. - -## Required Files: - -### `data.py` - -Inside `data.py` you should implement: - -1. one or more `DataSource` classes -2. a `Preprocess` -3. a `DataModule` -4. a `BaseVisualization` __(optional)__ -5. a `Postprocess` __(optional)__ - -#### `DataSource` - -The `DataSource` implementations describe how data from particular sources (like folders, files, tensors, etc.) should be loaded. -At a minimum you will require one `DataSource` implementation, but if you want to support a few different ways of loading data for your task, the more the merrier! -Take a look at our `TemplateDataSource` to get started. - -#### `Preprocess` - -The `Preprocess` is how all transforms are defined in Flash. -Internally we inject the `Preprocess` transforms into the right places so that we can address the batch at several points along the pipeline. -Defining the standard transforms (typically at least a `to_tensor_transform` and a `collate` should be defined) for your `Preprocess` is as simple as implementing the `default_transforms` method. -The `Preprocess` also knows about the available `DataSource` classes that it can work with, which should be configured in the `__init__`. -Take a look at our `TemplatePreprocess` to get started. - -#### `DataModule` - -The `DataModule` is where the hard work of our `DataSource` and `Preprocess` implementations pays off. -If your `DataSource` implementation(s) conform to our `DefaultDataSources` (e.g. `DefaultDataSources.FOLDERS`) then your `DataModule` implementation simply needs a `preprocess_cls` attribute. -You now have a `DataModule` that can be instantiated with `from_*` for whichever data sources you have configured (e.g. `MyDataModule.from_folders`). -It also includes all of your default transforms! - -If you've defined a fully custom `DataSource`, then you will need a `preprocess_cls` attribute and one or more `from_*` methods. -The `from_*` methods take whatever arguments you want them too and call `super().from_data_source` with the name given to your custom data source in the `Preprocess.__init__`. -Take a look at our `TemplateDataModule` to get started. - -#### `BaseVisualization` - -A completely optional step is to implement a `BaseVisualization`. The `BaseVisualization` lets you control how data at various points in the pipeline can be visualized. -This is extremely useful for debugging purposes, allowing users to view their data and understand the impact of their transforms. -Take a look at our `TemplateVisualization` to get started, but don't worry about implementing it right away, you can always come back and add it later! - -#### `Postprocess` - -Sometimes you have some transforms that need to be applied _after_ your model. -For this you can optionally implement a `Postprocess`. -The `Postprocess` is applied to the model outputs during inference. -You may want to use it for: converting tokens back into text, applying an inverse normalization to an output image, resizing a generated image back to the size of the input, etc. -For information and some examples, take a look at our postprocess docs. - -### `model.py` - -Inside `model.py` you just need to implement your `Task`. 
- -#### `Task` - -The `Task` is responsible for the forward pass of the model. -It's just a `LightningModule` with some helpful defaults, so anything you can do inside a `LightningModule` you can do inside a `Task`. -You should configure a default loss function and optimizer and some default metrics and models in your `Task`. -Take a look at our `TemplateTask` to get started. - -### `flash_examples` - -Now you've implemented your task, it's time to add some examples showing how cool it is! -We usually provide one finetuning example (in `flash_examples/finetuning`) and one predict / inference example (in `flash_examples/predict`). -You can base these off of our `template.py` examples. - -## Optional Files: - -### `transforms.py` - -Sometimes you'd like to have quite a few transforms by default (standard augmentations, normalization, etc.). -If you do then, for better organization, you can define a `transforms.py` which houses your default transforms to be referenced in your `Preprocess`. -Take a look at `vision/classification/transforms.py` for an example. - -### `backbones.py` - -In Flash, we love to provide as much access to the state-of-the-art as we can. -To this end, we've created the `FlashRegistry` and the backbones API. -These allow you to register backbones for your task that can be selected by the user. -The backbones can come from anywhere as long as you can register a function that loads the backbone. -If you want to configure some backbones for your task, it's best practice to include these in a `backbones.py` file. -Take a look at `vision/backbones.py` for an example, and have a look at `vision/classification/model.py` to see how these can be added to your `Task`. - -### `serialization.py` - -Sometimes you want to give the user some control over their prediction format. -`Postprocess` can do the heavy lifting (anything you always want to apply to the predictions), but one or more custom `Serializer` implementations can be used to convert the predictions to a desired output format. -A good example is in classification; sometimes we'd like the classes, sometimes the logits, sometimes the labels, you get the idea. -You should add your `Serializer` implementations in a `serialization.py` file and set a good default in your `Task`. From 04b3b9662cb8b4014b2e96fb4441a6915cd2d577 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Mon, 17 May 2021 13:00:46 +0100 Subject: [PATCH 06/53] Fixes --- docs/source/conf.py | 1 + docs/source/template/data.rst | 55 ++++++++++++++++------------- docs/source/template/optional.rst | 39 +++++++++++++++++---- flash/template/data.py | 58 +++++++++++++++---------------- 4 files changed, 93 insertions(+), 60 deletions(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index 561896cf09..0f917db09a 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -78,6 +78,7 @@ "torch": ("https://pytorch.org/docs/stable/", None), "numpy": ("https://docs.scipy.org/doc/numpy/", None), "PIL": ("https://pillow.readthedocs.io/en/stable/", None), + "pytorchvideo": ("https://pytorchvideo.readthedocs.io/en/latest/", None), "pytorch_lightning": ("https://pytorch-lightning.readthedocs.io/en/stable/", None), } diff --git a/docs/source/template/data.rst b/docs/source/template/data.rst index 3c4b19bf0d..8c036cffd1 100644 --- a/docs/source/template/data.rst +++ b/docs/source/template/data.rst @@ -7,23 +7,23 @@ The Data The first step to contributing a task is to implement the classes we need to load some data. Inside ``data.py`` you should implement: -#. 
one or more :class:`~flash.data.data_source.DataSource` classes
-#. a :class:`~flash.data.process.Preprocess`
-#. a :class:`~flash.data.data_module.DataModule`
-#. a :class:`~flash.data.callbacks.BaseVisualization` *(optional)*
-#. a :class:`~flash.data.process.Postprocess` *(optional)*
+#. one or more :class:`~flash.core.data.data_source.DataSource` classes
+#. a :class:`~flash.core.data.process.Preprocess`
+#. a :class:`~flash.core.data.data_module.DataModule`
+#. a :class:`~flash.core.data.callbacks.BaseVisualization` *(optional)*
+#. a :class:`~flash.core.data.process.Postprocess` *(optional)*

DataSource
^^^^^^^^^^

-The :class:`~flash.data.data_source.DataSource` implementations describe how data from particular sources (like folders, files, tensors, etc.) should be loaded.
-At a minimum you will require one :class:`~flash.data.data_source.DataSource` implementation, but if you want to support a few different ways of loading data for your task, the more the merrier!
-Each :class:`~flash.data.data_source.DataSource` has a ``load_data`` method and a ``load_sample`` method.
+The :class:`~flash.core.data.data_source.DataSource` implementations describe how data from particular sources (like folders, files, tensors, etc.) should be loaded.
+At a minimum you will require one :class:`~flash.core.data.data_source.DataSource` implementation, but if you want to support a few different ways of loading data for your task, the more the merrier!
+Each :class:`~flash.core.data.data_source.DataSource` has a ``load_data`` method and a ``load_sample`` method.
The ``load_data`` method accepts some dataset metadata (e.g. a folder name) and produces a sequence or iterable of samples or sample metadata.
The ``load_sample`` method then takes as input a single element from the output of ``load_data`` and returns a sample.
-By default, these methods just return their input as you will not always need both to create a functioning :class:`~flash.data.data_source.DataSource`.
+By default, these methods just return their input as you will not always need both to create a functioning :class:`~flash.core.data.data_source.DataSource`.

-I'ts best practice to just override one of our existing :class:`~flash.data.data_source.DataSource` classes where possible.
+I'ts best practice to just override one of our existing :class:`~flash.core.data.data_source.DataSource` classes where possible.
Take a look at our ``TemplateNumpyDataSource`` which does this to get started:

.. autoclass:: flash.template.data.TemplateNumpyDataSource
:members:

.. raw:: html

Source

.. literalinclude:: ../../../flash/template/data.py
:language: python
:pyobject: TemplateNumpyDataSource

.. raw:: html

Sometimes you need to do something a bit more custom, so have a look at our ``TemplateSKLearnDataSource`` for an example:

.. autoclass:: flash.template.data.TemplateSKLearnDataSource
:members:

.. raw:: html

Source

.. literalinclude:: ../../../flash/template/data.py
:language: python
:pyobject: TemplateSKLearnDataSource

.. raw:: html

DataSource vs Dataset
~~~~~~~~~~~~~~~~~~~~~

-A :class:`~flash.data.data_source.DataSource` is not the same as a :class:`torch.utils.data.Dataset`.
-A :class:`torch.utils.data.Dataset` knows about the data, whereas a :class:`~flash.data.data_source.DataSource` only know about how to load the data.
-So it's possible for a single :class:`~flash.data.data_source.DataSource` to create more than one :class:`~torch.utils.data.Dataset`.
+A :class:`~flash.core.data.data_source.DataSource` is not the same as a :class:`torch.utils.data.Dataset`.
+A :class:`torch.utils.data.Dataset` knows about the data, whereas a :class:`~flash.core.data.data_source.DataSource` only knows about how to load the data.
+So it's possible for a single :class:`~flash.core.data.data_source.DataSource` to create more than one :class:`~torch.utils.data.Dataset`.
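+As a rough sketch (the data tuples here are hypothetical, but ``generate_dataset`` is the real hook), the same data source can produce a dataset per split:
+
+.. code-block:: python
+
+    import numpy as np
+    from pytorch_lightning.trainer.states import RunningStage
+
+    from flash.template.data import TemplateNumpyDataSource
+
+    data_source = TemplateNumpyDataSource()
+    train_data = (np.random.rand(10, 4), np.random.randint(3, size=10))  # (inputs, targets)
+    val_data = (np.random.rand(4, 4), np.random.randint(3, size=4))
+
+    # One DataSource, two Datasets: each call runs load_data for the given split.
+    train_dataset = data_source.generate_dataset(train_data, RunningStage.TRAINING)
+    val_dataset = data_source.generate_dataset(val_data, RunningStage.VALIDATING)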
It's also fine for the output of the ``load_data`` method to just be a :class:`torch.utils.data.Dataset` instance.
You don't need to re-write custom datasets, just instantiate them in ``load_data`` similarly to what we did with the ``TemplateSKLearnDataSource``.
-For example, the ``load_data`` of the PyTorchVideo ``PathsDataSource`` just creates a :class:`pytorchvideo.data.encoded_video_dataset.EncodedVideoDataset` from the given folder.
+For example, the ``load_data`` of the ``VideoClassificationPathsDataSource`` just creates an :class:`~pytorchvideo.data.EncodedVideoDataset` from the given folder.
+Here's how it looks (from ``video/classification/data.py``):
+
+.. literalinclude:: ../../../flash/video/classification/data.py
+   :language: python
+   :pyobject: VideoClassificationPathsDataSource.load_data

Preprocess
^^^^^^^^^^

-The :class:`~flash.data.process.Preprocess` is how all transforms are defined in Flash.
-Internally we inject the :class:`~flash.data.process.Preprocess` transforms into the right places so that we can address the batch at several points along the pipeline.
-Defining the standard transforms (typically at least a ``to_tensor_transform`` should be defined) for your :class:`~flash.data.process.Preprocess` is as simple as implementing the ``default_transforms`` method.
-The :class:`~flash.data.process.Preprocess` also knows about the available :class:`~flash.data.data_source.DataSource` classes that it can work with, which should be configured in the ``__init__``.
+The :class:`~flash.core.data.process.Preprocess` is how all transforms are defined in Flash.
+Internally we inject the :class:`~flash.core.data.process.Preprocess` transforms into the right places so that we can address the batch at several points along the pipeline.
+Defining the standard transforms (typically at least a ``to_tensor_transform`` should be defined) for your :class:`~flash.core.data.process.Preprocess` is as simple as implementing the ``default_transforms`` method.
+The :class:`~flash.core.data.process.Preprocess` also knows about the available :class:`~flash.core.data.data_source.DataSource` classes that it can work with, which should be configured in the ``__init__``.

Take a look at our ``TemplatePreprocess`` to get started:

@@ -101,13 +106,13 @@ Take a look at our ``TemplatePreprocess`` to get started:

DataModule
^^^^^^^^^^

-The :class:`~flash.data.data_module.DataModule` is where the hard work of our :class:`~flash.data.data_source.DataSource` and :class:`~flash.data.process.Preprocess` implementations pays off.
-If your :class:`~flash.data.data_source.DataSource` implementation(s) conform to our :class:`~flash.data.data_source.DefaultDataSources` (e.g. ``DefaultDataSources.FOLDERS``) then your :class:`~flash.data.data_module.DataModule` implementation simply needs a ``preprocess_cls`` attribute.
-You now have a :class:`~flash.data.data_module.DataModule` that can be instantiated with ``from_*`` for whichever data sources you have configured (e.g. ``MyDataModule.from_folders``).
+The :class:`~flash.core.data.data_module.DataModule` is where the hard work of our :class:`~flash.core.data.data_source.DataSource` and :class:`~flash.core.data.process.Preprocess` implementations pays off.
+If your :class:`~flash.core.data.data_source.DataSource` implementation(s) conform to our :class:`~flash.core.data.data_source.DefaultDataSources` (e.g. ``DefaultDataSources.FOLDERS``) then your :class:`~flash.core.data.data_module.DataModule` implementation simply needs a ``preprocess_cls`` attribute.
+You now have a :class:`~flash.core.data.data_module.DataModule` that can be instantiated with ``from_*`` for whichever data sources you have configured (e.g. ``MyDataModule.from_folders``).
It also includes all of your default transforms!

-If you've defined a fully custom :class:`~flash.data.data_source.DataSource` (like our ``TemplateSKLearnDataSource``), then you will need a ``from_*`` method for each (we'll define ``from_sklearn`` for our example).
-The ``from_*`` methods take whatever arguments you want them to and call :meth:`~flash.data.data_module.DataModule.from_data_source` with the name given to your custom data source in the ``Preprocess.__init__``.
+If you've defined a fully custom :class:`~flash.core.data.data_source.DataSource` (like our ``TemplateSKLearnDataSource``), then you will need a ``from_*`` method for each (we'll define ``from_sklearn`` for our example).
+The ``from_*`` methods take whatever arguments you want them to and call :meth:`~flash.core.data.data_module.DataModule.from_data_source` with the name given to your custom data source in the ``Preprocess.__init__``.

Take a look at our ``TemplateData`` to get started:

@@ -159,8 +164,8 @@ Postprocess
^^^^^^^^^^^

Sometimes you have some transforms that need to be applied *after* your model.
-For this you can optionally implement a :class:`~flash.data.process.Postprocess`.
-The :class:`~flash.data.process.Postprocess` is applied to the model outputs during inference.
+For this you can optionally implement a :class:`~flash.core.data.process.Postprocess`.
+The :class:`~flash.core.data.process.Postprocess` is applied to the model outputs during inference.
You may want to use it for: converting tokens back into text, applying an inverse normalization to an output image, resizing a generated image back to the size of the input, etc.

------

diff --git a/docs/source/template/optional.rst b/docs/source/template/optional.rst
index 0593b42c90..bbb9396ec0 100644
--- a/docs/source/template/optional.rst
+++ b/docs/source/template/optional.rst
@@ -8,18 +8,45 @@ transforms.py
=============

Sometimes you'd like to have quite a few transforms by default (standard augmentations, normalization, etc.).
-If you do then, for better organization, you can define a `transforms.py` which houses your default transforms to be referenced in your `Preprocess`.
-Take a look at `vision/classification/transforms.py` for an example.
+If you do, then for better organization you can define a ``transforms.py`` which houses your default transforms to be referenced in your :class:`~flash.core.data.process.Preprocess`.
+Here's an example from ``image/classification/transforms.py`` which creates some default transforms given the desired image size:
+
+.. literalinclude:: ../../../flash/image/classification/transforms.py
+   :language: python
+   :pyobject: default_transforms
+
+Here's how we create our transforms in the :class:`~flash.image.classification.data.ImageClassificationPreprocess`:
+
+.. literalinclude:: ../../../flash/image/classification/data.py
+   :language: python
+   :pyobject: ImageClassificationPreprocess.default_transforms

backbones.py
============

In Flash, we love to provide as much access to the state-of-the-art as we can.
-To this end, we've created the `FlashRegistry` and the backbones API.
-These allow you to register backbones for your task that can be selected by the user.
+To this end, we've created the :any:`FlashRegistry `.
+The registry allows you to register backbones for your task that can be selected by the user.
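+In rough terms (the names here are made up, but :class:`~flash.core.registry.FlashRegistry` is real), registering and retrieving a backbone looks something like this:
+
+.. code-block:: python
+
+    from torch import nn
+
+    from flash.core.registry import FlashRegistry
+
+    MY_TASK_BACKBONES = FlashRegistry("backbones")
+
+    @MY_TASK_BACKBONES(name="mlp-128")
+    def load_mlp_128(num_features):
+        # Return the backbone along with the number of features it produces.
+        return nn.Sequential(nn.Linear(num_features, 128), nn.ReLU()), 128
+
+    backbone, num_out_features = MY_TASK_BACKBONES.get("mlp-128")(num_features=64)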
The backbones can come from anywhere as long as you can register a function that loads the backbone. -If you want to configure some backbones for your task, it's best practice to include these in a `backbones.py` file. -Take a look at `vision/backbones.py` for an example, and have a look at `vision/classification/model.py` to see how these can be added to your `Task`. +If you want to configure some backbones for your task, it's best practice to include these in a ``backbones.py`` file. +Here's an example adding ``SimCLR`` to the ``IMAGE_CLASSIFIER_BACKBONES``, from ``image/backbones.py``: + +.. literalinclude:: ../../../flash/image/backbones.py + :language: python + :pyobject: load_simclr_imagenet + +In ``image/classification/model.py``, we attach ``IMAGE_CLASSIFIER_BACKBONES`` to the :class:`~flash.image.classification.model.ImageClassifier` as a class attribute ``backbones``. +Now we get the backbone from the registry and create a head in the ``__init__``: + +.. literalinclude:: ../../../flash/image/classification/model.py + :language: python + :pyobject: ImageClassifier.__init__ + +Finally, we use our backbone and head in a custom forward pass: + +.. literalinclude:: ../../../flash/image/classification/model.py + :language: python + :pyobject: ImageClassifier.forward serialization.py ================ diff --git a/flash/template/data.py b/flash/template/data.py index da5de4d678..88c4bbe903 100644 --- a/flash/template/data.py +++ b/flash/template/data.py @@ -29,11 +29,11 @@ class TemplateNumpyDataSource(NumpyDataSource): """An example data source that records ``num_features`` on the dataset. We extend - :class:`~flash.data.data_source.NumpyDataSource` so that we can use ``super().load_data``.""" + :class:`~flash.core.data.data_source.NumpyDataSource` so that we can use ``super().load_data``.""" def load_data(self, data: Tuple[np.ndarray, Sequence[Any]], dataset: Any) -> Sequence[Mapping[str, Any]]: - """The main :class:`~flash.data.data_source.DataSource` method that we have to implement is - :meth:`~flash.data.data_source.DataSource.load_data`. As we're extending the ``NumpyDataSource``, we expect the + """The main :class:`~flash.core.data.data_source.DataSource` method that we have to implement is + :meth:`~flash.core.data.data_source.DataSource.load_data`. As we're extending the ``NumpyDataSource``, we expect the same ``data`` argument (in this case, a tuple containing data and corresponding target arrays). We can also take the dataset argument. Any attributes we set on ``dataset`` will be available on the ``Dataset`` @@ -57,12 +57,12 @@ class TemplateSKLearnDataSource(TemplateNumpyDataSource): """An example data source that loads data from an sklearn data ``Bunch``.""" def load_data(self, data: Bunch, dataset: Any) -> Sequence[Mapping[str, Any]]: - """Here we're creating a fully custom :class:`~flash.data.data_source.DataSource` (that is, we're not going to - treat it as one of the :class:`~flash.data.data_source.DefaultDataSources`) so the type of the ``data`` argument + """Here we're creating a fully custom :class:`~flash.core.data.data_source.DataSource` (that is, we're not going to + treat it as one of the :class:`~flash.core.data.data_source.DefaultDataSources`) so the type of the ``data`` argument is up to us. In this case, we want to be able to use a scikit-learn data ``Bunch`` as an input. On our dataset, we'll set the ``num_classes`` attribute. 
This is a standard practice in our classification tasks - and ``num_classes`` will automatically be made available by the :class:`~flash.data.data_module.DataModule`. + and ``num_classes`` will automatically be made available by the :class:`~flash.core.data.data_module.DataModule`. Args: data: The scikit-learn data ``Bunch``. @@ -83,20 +83,20 @@ def load_data(self, data: Bunch, dataset: Any) -> Sequence[Mapping[str, Any]]: class TemplatePreprocess(Preprocess): - """The next thing for us to implement is the :class:`~flash.data.process.Preprocess`. The - :class:`~flash.data.process.Preprocess` must take ``train_transform``, ``val_transform``, ``test_transform``, and + """The next thing for us to implement is the :class:`~flash.core.data.process.Preprocess`. The + :class:`~flash.core.data.process.Preprocess` must take ``train_transform``, ``val_transform``, ``test_transform``, and ``predict_transform`` arguments in the ``__init__``. Any additional arguments for it to take are up to you. Inside the ``__init__``, we make a call to super. This is where we register our data sources. Data sources should be given as a dictionary which maps data source name to data source object. The name can be anything, but if you want to take advantage of our built-in ``from_*`` classmethods, you should use - :class:`~flash.data.data_source.DefaultDataSources` as the names. In our case, we have both a - :attr:`~flash.data.data_source.DefaultDataSources.NUMPY` and a custom scikit-learn data source (which we'll call + :class:`~flash.core.data.data_source.DefaultDataSources` as the names. In our case, we have both a + :attr:`~flash.core.data.data_source.DefaultDataSources.NUMPY` and a custom scikit-learn data source (which we'll call "sklearn"). We can also provide a ``default_data_source``. This is the name of the data source to use by default when predicting. It'd be cool if we could get predictions just from a numpy array, so let's use the - :attr:`~flash.data.data_source.DefaultDataSources.NUMPY` as the default. + :attr:`~flash.core.data.data_source.DefaultDataSources.NUMPY` as the default. Args: train_transform: The user-specified transforms to apply during training. @@ -139,12 +139,12 @@ def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool = False): return cls(**state_dict) def default_transforms(self) -> Optional[Dict[str, Callable]]: - """Your :class:`~flash.data.process.Preprocess` should usually define some default transforms. Generally, we at + """Your :class:`~flash.core.data.process.Preprocess` should usually define some default transforms. Generally, we at least want to convert to a tensor, so let's do that here. Our inputs samples will be dictionaries whose keys are from the - :class:`~flash.data.data_source.DefaultDataKeys`, so we need to map each key to different transforms using - :class:`~flash.data.transforms.ApplyToKeys`. By convention, we apply sequences of transforms by wrapping them in + :class:`~flash.core.data.data_source.DefaultDataKeys`, so we need to map each key to different transforms using + :class:`~flash.core.data.transforms.ApplyToKeys`. By convention, we apply sequences of transforms by wrapping them in an ``nn.Sequential``. Returns: @@ -163,7 +163,7 @@ def default_transforms(self) -> Optional[Dict[str, Callable]]: class TemplateData(DataModule): - """Creating our :class:`~flash.data.data_module.DataModule` is as easy as setting the ``preprocess_cls`` attribute. 
+ """Creating our :class:`~flash.core.data.data_module.DataModule` is as easy as setting the ``preprocess_cls`` attribute. We get the ``from_numpy`` method for free as we've configured a ``DefaultDataSources.NUMPY`` data source. We'll also add a ``from_sklearn`` method so that we can use our ``TemplateSKLearnDataSource. Finally, we define the ``num_features`` property for convenience. @@ -190,9 +190,9 @@ def from_sklearn( **preprocess_kwargs: Any, ): """This is our custom ``from_*`` method. It expects scikit-learn ``Bunch`` objects as input and passes them - through to the :meth:`~flash.data.data_module.DataModule.from_data_source` method underneath. It's really just + through to the :meth:`~flash.core.data.data_module.DataModule.from_data_source` method underneath. It's really just a convenience method to save the user from needing to call - :meth:`~flash.data.data_module.DataModule.from_data_source` directly. + :meth:`~flash.core.data.data_module.DataModule.from_data_source` directly. Args: train_bunch: The scikit-learn ``Bunch`` containing the train data. @@ -200,21 +200,21 @@ def from_sklearn( test_bunch: The scikit-learn ``Bunch`` containing the test data. predict_bunch: The scikit-learn ``Bunch`` containing the predict data. train_transform: The dictionary of transforms to use during training which maps - :class:`~flash.data.process.Preprocess` hook names to callable transforms. + :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. val_transform: The dictionary of transforms to use during validation which maps - :class:`~flash.data.process.Preprocess` hook names to callable transforms. + :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. test_transform: The dictionary of transforms to use during testing which maps - :class:`~flash.data.process.Preprocess` hook names to callable transforms. + :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. predict_transform: The dictionary of transforms to use during predicting which maps - :class:`~flash.data.process.Preprocess` hook names to callable transforms. - data_fetcher: The :class:`~flash.data.callback.BaseDataFetcher` to pass to the - :class:`~flash.data.data_module.DataModule`. - preprocess: The :class:`~flash.data.data.Preprocess` to pass to the - :class:`~flash.data.data_module.DataModule`. If ``None``, ``cls.preprocess_cls`` will be constructed + :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + data_fetcher: The :class:`~flash.core.data.callback.BaseDataFetcher` to pass to the + :class:`~flash.core.data.data_module.DataModule`. + preprocess: The :class:`~flash.core.data.data.Preprocess` to pass to the + :class:`~flash.core.data.data_module.DataModule`. If ``None``, ``cls.preprocess_cls`` will be constructed and used. - val_split: The ``val_split`` argument to pass to the :class:`~flash.data.data_module.DataModule`. - batch_size: The ``batch_size`` argument to pass to the :class:`~flash.data.data_module.DataModule`. - num_workers: The ``num_workers`` argument to pass to the :class:`~flash.data.data_module.DataModule`. + val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. + batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. + num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. 
preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used if ``preprocess = None``. @@ -256,7 +256,7 @@ def configure_data_fetcher(*args, **kwargs) -> BaseDataFetcher: class TemplateVisualization(BaseVisualization): - """The ``TemplateVisualization`` class is a :class:`~flash.data.callbacks.BaseVisualization` that just prints the + """The ``TemplateVisualization`` class is a :class:`~flash.core.data.callbacks.BaseVisualization` that just prints the data. If you want to provide a visualization with your task, you can override these hooks.""" def show_load_sample(self, samples: List[Any], running_stage: RunningStage): From f79f9095c6d8e8e67742db859c70bb2ffaa7e190 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Mon, 17 May 2021 13:19:41 +0100 Subject: [PATCH 07/53] Updates --- docs/source/template/optional.rst | 23 ++++++++++++++++++++--- 1 file changed, 20 insertions(+), 3 deletions(-) diff --git a/docs/source/template/optional.rst b/docs/source/template/optional.rst index bbb9396ec0..a34f5d194f 100644 --- a/docs/source/template/optional.rst +++ b/docs/source/template/optional.rst @@ -52,6 +52,23 @@ serialization.py ================ Sometimes you want to give the user some control over their prediction format. -`Postprocess` can do the heavy lifting (anything you always want to apply to the predictions), but one or more custom `Serializer` implementations can be used to convert the predictions to a desired output format. -A good example is in classification; sometimes we'd like the classes, sometimes the logits, sometimes the labels, you get the idea. -You should add your `Serializer` implementations in a `serialization.py` file and set a good default in your `Task`. +:class:`~flash.core.data.process.Postprocess` can do the heavy lifting (anything you always want to apply to the predictions), but one or more custom :class:`~flash.core.data.process.Serializer` implementations can be used to convert the predictions to a desired output format. +You should add your :class:`~flash.core.data.process.Serializer` implementations in a ``serialization.py`` file and set a good default in your :class:`~flash.core.model.Task`. +Some good examples are in ``core/classification.py``. +Here's the :class:`~flash.core.classification.Classes` :class:`~flash.core.data.process.Serializer`: + +.. literalinclude:: ../../../flash/core/classification.py + :language: python + :pyobject: Classes + +Alternatively, here's the :class:`~flash.core.classification.Logits` :class:`~flash.core.data.process.Serializer`: + +.. literalinclude:: ../../../flash/core/classification.py + :language: python + :pyobject: Logits + +Take a look at :ref:`predictions` to learn more. + +------ + +Once you've added any optional extras, it's time to :ref:`create some examples showing your task in action! 
`

From 9834a47beb319585406ac7e05f196eb959e6951a Mon Sep 17 00:00:00 2001
From: Ethan Harris
Date: Mon, 17 May 2021 14:33:45 +0100
Subject: [PATCH 08/53] Add examples

---
docs/source/template/examples.rst | 52 ++++++++++++++++++-
flash/template/__init__.py | 2 +
flash/template/data.py | 20 ++++++-
.../flash_examples/finetuning/template.py | 52 +++++++++++++++++++
.../flash_examples/predict/template.py | 38 ++++++++++++++
flash/template/model.py | 5 +-
6 files changed, 164 insertions(+), 5 deletions(-)
create mode 100644 flash/template/flash_examples/finetuning/template.py
create mode 100644 flash/template/flash_examples/predict/template.py

diff --git a/docs/source/template/examples.rst b/docs/source/template/examples.rst
index 93631cfc13..acb52b5767 100644
--- a/docs/source/template/examples.rst
+++ b/docs/source/template/examples.rst
@@ -5,12 +5,60 @@ The Examples
************

Now you've implemented your task, it's time to add some examples showing how cool it is!
-We usually provide one finetuning example (in `flash_examples/finetuning`) and one predict / inference example (in `flash_examples/predict`).
-You can base these off of our `template.py` examples.
+We usually provide one finetuning example in ``flash_examples/finetuning`` and one predict / inference example in ``flash_examples/predict``.
+You can base these off of our ``template.py`` examples.
Let's take a closer look.

finetuning
==========

+The finetuning example should:
+
+#. download the data
+#. load the data into a :class:`~flash.core.data.data_module.DataModule`
+#. create an instance of the :class:`~flash.core.model.Task`
+#. create a :class:`~flash.core.trainer.Trainer`
+#. call :meth:`~flash.core.trainer.Trainer.finetune` or :meth:`~flash.core.trainer.Trainer.fit` to train your model
+#. save the checkpoint
+#. generate predictions for a few examples *(optional)*
+
+For our template example we don't have a pretrained backbone, so we can just call :meth:`~flash.core.trainer.Trainer.fit` rather than :meth:`~flash.core.trainer.Trainer.finetune`.
+Here's the full example:
+
+.. literalinclude:: ../../../flash/template/flash_examples/finetuning/template.py
+    :language: python
+    :lines: 14-
+
+We get this output:
+
+.. code-block::
+
+    ['setosa', 'virginica', 'versicolor']
+
predict
=======
+
+The predict example should:
+
+#. download the data
+#. load an instance of the :class:`~flash.core.model.Task` from a checkpoint stored on `S3` (speak with one of us about getting your checkpoint hosted)
+#. generate predictions for a few examples
+#. generate predictions for a whole dataset, folder, etc.
+
+For our template example, the model was already trained by the finetuning example above, so here we just load that checkpoint and generate some predictions.
+Here's the full example:
+
+.. literalinclude:: ../../../flash/template/flash_examples/predict/template.py
+    :language: python
+    :lines: 14-
+
+We get this output:
+
+.. code-block::
+
+    ['setosa', 'virginica', 'versicolor']
+    [['setosa', 'setosa', 'setosa', 'setosa'], ..., ['virginica', 'virginica']]
+
+------
+
+Now that you've got some examples showing your awesome task in action, it's time to :ref:`write some tests!
` diff --git a/flash/template/__init__.py b/flash/template/__init__.py index e69de29bb2..d0d1c5b3be 100644 --- a/flash/template/__init__.py +++ b/flash/template/__init__.py @@ -0,0 +1,2 @@ +from flash.template.data import TemplateData +from flash.template.model import TemplateSKLearnClassifier diff --git a/flash/template/data.py b/flash/template/data.py index 88c4bbe903..73788762f2 100644 --- a/flash/template/data.py +++ b/flash/template/data.py @@ -81,6 +81,19 @@ def load_data(self, data: Bunch, dataset: Any) -> Sequence[Mapping[str, Any]]: # Now we just call super with the data and targets return super().load_data((data.data, data.target), dataset=dataset) + def predict_load_data(self, data: Bunch) -> Sequence[Mapping[str, Any]]: + """You can prepend ``train``, ``val``, ``test``, or ``predict`` to ``load_data`` in order to customize the + behaviour for a particular stage. In this case, we implement ``predict_load_data`` to avoid including targets + when predicting. + + Args: + data: The scikit-learn data ``Bunch``. + + Returns: + A sequence of samples / sample metadata. + """ + return super().predict_load_data(data.data) + class TemplatePreprocess(Preprocess): """The next thing for us to implement is the :class:`~flash.core.data.process.Preprocess`. The @@ -138,6 +151,11 @@ def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool = False): """ return cls(**state_dict) + @staticmethod + def input_to_tensor(input: np.ndarray): + """Transform which creates a tensor from the given numpy ``ndarray`` and converts it to ``float``""" + return torch.from_numpy(input).float() + def default_transforms(self) -> Optional[Dict[str, Callable]]: """Your :class:`~flash.core.data.process.Preprocess` should usually define some default transforms. Generally, we at least want to convert to a tensor, so let's do that here. @@ -152,7 +170,7 @@ def default_transforms(self) -> Optional[Dict[str, Callable]]: """ return { "to_tensor_transform": nn.Sequential( - ApplyToKeys(DefaultDataKeys.INPUT, torch.from_numpy), + ApplyToKeys(DefaultDataKeys.INPUT, self.input_to_tensor), ApplyToKeys(DefaultDataKeys.TARGET, torch.as_tensor), ), } diff --git a/flash/template/flash_examples/finetuning/template.py b/flash/template/flash_examples/finetuning/template.py new file mode 100644 index 0000000000..f6e0f9fbbb --- /dev/null +++ b/flash/template/flash_examples/finetuning/template.py @@ -0,0 +1,52 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import numpy as np +from sklearn import datasets + +import flash +from flash.core.classification import Labels +from flash.template import TemplateData, TemplateSKLearnClassifier + +# 1. Download the data +data_bunch = datasets.load_iris() + +# 2. Load the data +datamodule = TemplateData.from_sklearn( + train_bunch=data_bunch, + val_split=0.8, +) + +# 3. Build the model +model = TemplateSKLearnClassifier( + num_features=datamodule.num_features, + num_classes=datamodule.num_classes, + serializer=Labels(), +) + +# 4. Create the trainer. 
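+# Note: flash.Trainer subclasses the Lightning Trainer, so any Lightning Trainer
+# argument (gpus, max_epochs, etc.) can be passed here as well.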
+trainer = flash.Trainer(max_epochs=20)
+
+# 5. Train the model
+trainer.fit(model, datamodule=datamodule)
+
+# 6. Save it!
+trainer.save_checkpoint("template_model.pt")
+
+# 7. Classify a few examples
+predictions = model.predict([
+    np.array([4.9, 3.0, 1.4, 0.2]),
+    np.array([6.9, 3.2, 5.7, 2.3]),
+    np.array([7.2, 3.0, 5.8, 1.6]),
+])
+print(predictions)

diff --git a/flash/template/flash_examples/predict/template.py b/flash/template/flash_examples/predict/template.py
new file mode 100644
index 0000000000..414b10633f
--- /dev/null
+++ b/flash/template/flash_examples/predict/template.py
@@ -0,0 +1,38 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import numpy as np
+from sklearn import datasets
+
+from flash import Trainer
+from flash.template import TemplateData, TemplateSKLearnClassifier
+
+# 1. Download the data
+data_bunch = datasets.load_iris()
+
+# 2. Load the model from a checkpoint
+model = TemplateSKLearnClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/template_model.pt")
+
+# 3. Classify a few examples!
+predictions = model.predict([
+    np.array([4.9, 3.0, 1.4, 0.2]),
+    np.array([6.9, 3.2, 5.7, 2.3]),
+    np.array([7.2, 3.0, 5.8, 1.6]),
+])
+print(predictions)
+
+# 4. Or generate predictions from a whole dataset!
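+# The Bunch is routed through the custom "sklearn" data source, so only the
+# inputs are used here; Trainer.predict returns a list of predictions per batch.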
+datamodule = TemplateData.from_sklearn(predict_bunch=data_bunch) + +predictions = Trainer().predict(model, datamodule=datamodule) +print(predictions) diff --git a/flash/template/model.py b/flash/template/model.py index d8cbbe1821..a2b42fafd8 100644 --- a/flash/template/model.py +++ b/flash/template/model.py @@ -49,20 +49,21 @@ def __init__( self, num_features: int, num_classes: int, - hidden_size: 128, + hidden_size: int = 128, loss_fn: Optional[Callable] = None, optimizer: Union[Type[torch.optim.Optimizer], torch.optim.Optimizer] = torch.optim.Adam, optimizer_kwargs: Optional[Dict[str, Any]] = None, scheduler: Optional[Union[Type[_LRScheduler], str, _LRScheduler]] = None, scheduler_kwargs: Optional[Dict[str, Any]] = None, metrics: Union[torchmetrics.Metric, Mapping, Sequence, None] = None, - learning_rate: float = 1e-3, + learning_rate: float = 1e-2, multi_label: bool = False, serializer: Optional[Union[Serializer, Mapping[str, Serializer]]] = None, ): model = nn.Sequential( nn.Linear(num_features, hidden_size), nn.ReLU(True), + nn.BatchNorm1d(hidden_size), nn.Linear(hidden_size, num_classes), ) From 53a1ba292a9650fe392815f2872d310bf1d710a7 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Mon, 17 May 2021 14:46:16 +0100 Subject: [PATCH 09/53] Updates --- docs/source/template/data.rst | 1 - .../flash_examples => flash_examples}/finetuning/template.py | 0 .../flash_examples => flash_examples}/predict/template.py | 0 3 files changed, 1 deletion(-) rename {flash/template/flash_examples => flash_examples}/finetuning/template.py (100%) rename {flash/template/flash_examples => flash_examples}/predict/template.py (100%) diff --git a/docs/source/template/data.rst b/docs/source/template/data.rst index 8c036cffd1..3106a444b3 100644 --- a/docs/source/template/data.rst +++ b/docs/source/template/data.rst @@ -114,7 +114,6 @@ It also includes all of your default transforms! If you've defined a fully custom :class:`~flash.core.data.data_source.DataSource` (like our ``TemplateSKLearnDataSource``), then you will need a ``from_*`` method for each (we'll define ``from_sklearn`` for our example). The ``from_*`` methods take whatever arguments you want them to and call :meth:`~flash.core.data.data_module.DataModule.from_data_source` with the name given to your custom data source in the ``Preprocess.__init__``. - Take a look at our ``TemplateData`` to get started: .. 
autoclass:: flash.template.data.TemplateData diff --git a/flash/template/flash_examples/finetuning/template.py b/flash_examples/finetuning/template.py similarity index 100% rename from flash/template/flash_examples/finetuning/template.py rename to flash_examples/finetuning/template.py diff --git a/flash/template/flash_examples/predict/template.py b/flash_examples/predict/template.py similarity index 100% rename from flash/template/flash_examples/predict/template.py rename to flash_examples/predict/template.py From 2694f46c4e9650a584516a665e89b4e0560c83ba Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Mon, 17 May 2021 14:47:00 +0100 Subject: [PATCH 10/53] Updates --- docs/source/template/examples.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/template/examples.rst b/docs/source/template/examples.rst index acb52b5767..489d458be0 100644 --- a/docs/source/template/examples.rst +++ b/docs/source/template/examples.rst @@ -25,7 +25,7 @@ The finetuning example should: For our template example we don't have a pretrained backbone, so we can just call :meth:`~flash.core.trainer.Trainer.fit` rather than :meth:`~flash.core.trainer.Trainer.finetune`. Here's the full example: -.. literalinclude:: ../../../flash/template/flash_examples/finetuning/template.py +.. literalinclude:: ../../../flash_examples/finetuning/template.py :language: python :lines: 14- @@ -48,7 +48,7 @@ The predict example should: For our template example we don't have a pretrained backbone, so we can just call :meth:`~flash.core.trainer.Trainer.fit` rather than :meth:`~flash.core.trainer.Trainer.finetune`. Here's the full example: -.. literalinclude:: ../../../flash/template/flash_examples/predict/template.py +.. literalinclude:: ../../../flash_examples/predict/template.py :language: python :lines: 14- From 28b5eece1d6f40e3e7622e41a0cda27e9c50d132 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Mon, 17 May 2021 14:49:45 +0100 Subject: [PATCH 11/53] Updates --- flash_examples/finetuning/template.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flash_examples/finetuning/template.py b/flash_examples/finetuning/template.py index f6e0f9fbbb..fcd35d6c35 100644 --- a/flash_examples/finetuning/template.py +++ b/flash_examples/finetuning/template.py @@ -35,7 +35,7 @@ ) # 4. Create the trainer. -trainer = flash.Trainer(max_epochs=20) +trainer = flash.Trainer(max_epochs=1, limit_train_batches=1, limit_val_batches=1) # 5. 
Train the model
trainer.fit(model, datamodule=datamodule)

# 6. Save it!
trainer.save_checkpoint("template_model.pt")

From c552635ce96524e8f2cbd4ca4b6576eb49f08d1c Mon Sep 17 00:00:00 2001
From: Ethan Harris
Date: Mon, 17 May 2021 15:48:07 +0100
Subject: [PATCH 12/53] Updates

---
docs/source/template/data.rst | 26 +++++++++----------
docs/source/template/examples.rst | 4 +--
docs/source/template/intro.rst | 20 +++++++++-----
docs/source/template/model.rst | 4 +--
flash/template/__init__.py | 3 +--
flash/template/classification/__init__.py | 2 ++
flash/template/{ => classification}/data.py | 0
flash/template/{ => classification}/model.py | 0
.../predict/image_classification.py | 3 +++
9 files changed, 36 insertions(+), 26 deletions(-)
create mode 100644 flash/template/classification/__init__.py
rename flash/template/{ => classification}/data.py (100%)
rename flash/template/{ => classification}/model.py (100%)

diff --git a/docs/source/template/data.rst b/docs/source/template/data.rst
index 3106a444b3..d989f0b5dd 100644
--- a/docs/source/template/data.rst
+++ b/docs/source/template/data.rst
@@ -21,12 +21,12 @@ At a minimum you will require one :class:`~flash.core.data.data_source.DataSourc
 Each :class:`~flash.core.data.data_source.DataSource` has a ``load_data`` method and a ``load_sample`` method.
 The ``load_data`` method accepts some dataset metadata (e.g. a folder name) and produces a sequence or iterable of samples or sample metadata.
 The ``load_sample`` method then takes as input a single element from the output of ``load_data`` and returns a sample.
-By default, these methods just return their input as you will not always need both to create a functioning :class:`~flash.core.data.data_source.DataSource`.
+By default, these methods just return their input, so you will not always need both methods to create a :class:`~flash.core.data.data_source.DataSource`.

-I'ts best practice to just override one of our existing :class:`~flash.core.data.data_source.DataSource` classes where possible.
-Take a look at our ``TemplateNumpyDataSource`` which does this to get started:
+It's best practice to just override one of our existing :class:`~flash.core.data.data_source.DataSource` classes where possible.
+Take a look at our ``TemplateNumpyDataSource`` (which does this) to get started:

-.. autoclass:: flash.template.data.TemplateNumpyDataSource
+.. autoclass:: flash.template.classification.data.TemplateNumpyDataSource
 :members:

 .. raw:: html

@@ -34,7 +34,7 @@ Take a look at our ``TemplateNumpyDataSource`` which does this to get started:
Source

-.. literalinclude:: ../../../flash/template/data.py
+.. literalinclude:: ../../../flash/template/classification/data.py
 :language: python
 :pyobject: TemplateNumpyDataSource

@@ -46,7 +46,7 @@ Sometimes you need to do something a bit more custom, so have a look at our ``TemplateSKLearnDataSource`` for an example:

-.. autoclass:: flash.template.data.TemplateSKLearnDataSource
+.. autoclass:: flash.template.classification.data.TemplateSKLearnDataSource
 :members:

 .. raw:: html

@@ -54,7 +54,7 @@ Sometimes you need to do something a bit more custom, so have a look at our ``TemplateSKLearnDataSource`` for an example:
Source -.. literalinclude:: ../../../flash/template/data.py +.. literalinclude:: ../../../flash/template/classification/data.py :language: python :pyobject: TemplateSKLearnDataSource @@ -87,7 +87,7 @@ The :class:`~flash.core.data.process.Preprocess` also knows about the available Take a look at our ``TemplatePreprocess`` to get started: -.. autoclass:: flash.template.data.TemplatePreprocess +.. autoclass:: flash.template.classification.data.TemplatePreprocess :members: .. raw:: html @@ -95,7 +95,7 @@ Take a look at our ``TemplatePreprocess`` to get started:
Source -.. literalinclude:: ../../../flash/template/data.py +.. literalinclude:: ../../../flash/template/classification/data.py :language: python :pyobject: TemplatePreprocess @@ -116,7 +116,7 @@ The ``from_*`` methods take whatever arguments you want them to and call :meth:` Take a look at our ``TemplateData`` to get started: -.. autoclass:: flash.template.data.TemplateData +.. autoclass:: flash.template.classification.data.TemplateData :members: .. raw:: html @@ -124,7 +124,7 @@ Take a look at our ``TemplateData`` to get started:
Source -.. literalinclude:: ../../../flash/template/data.py +.. literalinclude:: ../../../flash/template/classification/data.py :language: python :pyobject: TemplateData @@ -143,7 +143,7 @@ Take a look at our ``TemplateVisualization`` to get started: .. note:: Don't worry about implementing it right away, you can always come back and add it later! -.. autoclass:: flash.template.data.TemplateVisualization +.. autoclass:: flash.template.classification.data.TemplateVisualization :members: .. raw:: html @@ -151,7 +151,7 @@ Take a look at our ``TemplateVisualization`` to get started:
Source -.. literalinclude:: ../../../flash/template/data.py +.. literalinclude:: ../../../flash/template/classification/data.py :language: python :pyobject: TemplateVisualization diff --git a/docs/source/template/examples.rst b/docs/source/template/examples.rst index 489d458be0..eaf6d328e9 100644 --- a/docs/source/template/examples.rst +++ b/docs/source/template/examples.rst @@ -23,7 +23,7 @@ The finetuning example should: #. generate predictions for a few examples *(optional)* For our template example we don't have a pretrained backbone, so we can just call :meth:`~flash.core.trainer.Trainer.fit` rather than :meth:`~flash.core.trainer.Trainer.finetune`. -Here's the full example: +Here's the full example (``flash_examples/finetuning/template.py``): .. literalinclude:: ../../../flash_examples/finetuning/template.py :language: python @@ -46,7 +46,7 @@ The predict example should: #. generate predictions for a whole dataset, folder, etc. For our template example we don't have a pretrained backbone, so we can just call :meth:`~flash.core.trainer.Trainer.fit` rather than :meth:`~flash.core.trainer.Trainer.finetune`. -Here's the full example: +Here's the full example (``flash_examples/predict/template.py``): .. literalinclude:: ../../../flash_examples/predict/template.py :language: python diff --git a/docs/source/template/intro.rst b/docs/source/template/intro.rst index 4dc097b3d0..c0997dedac 100644 --- a/docs/source/template/intro.rst +++ b/docs/source/template/intro.rst @@ -1,23 +1,29 @@ .. _contributing: -************ -Introduction -************ +********************* +Introduction / Set-up +********************* Welcome ======= -Before you begin, we'd like to express our sincere gratitude to you for wanting to add a task to Flash. +Before you begin, we'd like to express our gratitude to you for wanting to add a task to Flash. With Flash our aim is to create a great user experience, enabling awesome advanced applications with just a few lines of code. We're really pleased with what we've achieved with Flash and we hope you will be too. Now let's dive in! -Tutorials -========= +Set-up +====== The Task template is designed to guide you through contributing a task to Flash. It contains the code, tests, and examples for a task that performs classification with a multi-layer perceptron, intended for use with the classic data sets from scikit-learn. -You should copy the files in ``flash/template`` to get started. +Before you begin, copy the files in ``flash/template/classification`` to the location where you are implementing your task. +Our folders are organised in terms of data-type (image, text, video, etc.), with sub-folders for different task types (classification, regression, etc.). +If a data-type folder already exists for your task, then a task type sub-folder should be added containing the template files. +If a data-type folder doesn't exist, then you will need to add that too. + +Tutorials +========= The tutorials in this section will walk you through all of the components you need to implement (or adapt from the template) for your custom task. diff --git a/docs/source/template/model.rst b/docs/source/template/model.rst index 6b0dc371e7..816c6c8dac 100644 --- a/docs/source/template/model.rst +++ b/docs/source/template/model.rst @@ -19,7 +19,7 @@ In our template example, we just extract the input and target from the input map Here's our ``TemplateSKLearnClassifier``: -.. autoclass:: flash.template.model.TemplateSKLearnClassifier +.. 
autoclass:: flash.template.classification.model.TemplateSKLearnClassifier :members: .. raw:: html @@ -27,7 +27,7 @@ Here's our ``TemplateSKLearnClassifier``:
Source -.. literalinclude:: ../../../flash/template/model.py +.. literalinclude:: ../../../flash/template/classification/model.py :language: python :pyobject: TemplateSKLearnClassifier diff --git a/flash/template/__init__.py b/flash/template/__init__.py index d0d1c5b3be..2de73a02c9 100644 --- a/flash/template/__init__.py +++ b/flash/template/__init__.py @@ -1,2 +1 @@ -from flash.template.data import TemplateData -from flash.template.model import TemplateSKLearnClassifier +from flash.template.classification import TemplateData, TemplateSKLearnClassifier diff --git a/flash/template/classification/__init__.py b/flash/template/classification/__init__.py new file mode 100644 index 0000000000..a1958a126a --- /dev/null +++ b/flash/template/classification/__init__.py @@ -0,0 +1,2 @@ +from flash.template.classification.data import TemplateData +from flash.template.classification.model import TemplateSKLearnClassifier diff --git a/flash/template/data.py b/flash/template/classification/data.py similarity index 100% rename from flash/template/data.py rename to flash/template/classification/data.py diff --git a/flash/template/model.py b/flash/template/classification/model.py similarity index 100% rename from flash/template/model.py rename to flash/template/classification/model.py diff --git a/flash_examples/predict/image_classification.py b/flash_examples/predict/image_classification.py index 19d010e6e5..0c07a278be 100644 --- a/flash_examples/predict/image_classification.py +++ b/flash_examples/predict/image_classification.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from flash import Trainer +from flash.core.classification import Probabilities from flash.core.data.utils import download_data from flash.image import ImageClassificationData, ImageClassifier @@ -22,6 +23,8 @@ model = ImageClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/image_classification_model.pt") # 3a. Predict what's on a few images! ants or bees? + +model.serializer = Probabilities() predictions = model.predict([ "data/hymenoptera_data/val/bees/65038344_52a45d090d.jpg", "data/hymenoptera_data/val/bees/590318879_68cf112861.jpg", From 65f9bdd4867da4d13db91d411cc1fdde4df41f06 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Mon, 17 May 2021 17:06:49 +0100 Subject: [PATCH 13/53] Add tests --- docs/source/template/intro.rst | 1 + tests/template/classification/__init__.py | 0 tests/template/classification/test_data.py | 118 ++++++++++++++++++++ tests/template/classification/test_model.py | 100 +++++++++++++++++ 4 files changed, 219 insertions(+) create mode 100644 tests/template/classification/__init__.py create mode 100644 tests/template/classification/test_data.py create mode 100644 tests/template/classification/test_model.py diff --git a/docs/source/template/intro.rst b/docs/source/template/intro.rst index c0997dedac..39f016a02a 100644 --- a/docs/source/template/intro.rst +++ b/docs/source/template/intro.rst @@ -21,6 +21,7 @@ Before you begin, copy the files in ``flash/template/classification`` to the loc Our folders are organised in terms of data-type (image, text, video, etc.), with sub-folders for different task types (classification, regression, etc.). If a data-type folder already exists for your task, then a task type sub-folder should be added containing the template files. If a data-type folder doesn't exist, then you will need to add that too. 
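+For example, if you were contributing a classification task for a hypothetical ``graph`` data-type (the name here is purely illustrative), the copied template files would end up laid out something like this:
+
+.. code-block:: text
+
+    flash/
+    └── graph/
+        ├── __init__.py
+        └── classification/
+            ├── __init__.py
+            ├── data.py
+            └── model.py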
+You should also copy the files from ``tests/template/classification`` to the corresponding data-type and task-type folder in ``tests``.

 Tutorials
 =========
diff --git a/tests/template/classification/__init__.py b/tests/template/classification/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/template/classification/test_data.py b/tests/template/classification/test_data.py
new file mode 100644
index 0000000000..d85fc525ce
--- /dev/null
+++ b/tests/template/classification/test_data.py
@@ -0,0 +1,118 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import numpy as np
+import pytest
+
+from flash.core.data.data_source import DefaultDataKeys
+from flash.core.utilities.imports import _SKLEARN_AVAILABLE
+from flash.template.classification.data import TemplateData, TemplatePreprocess
+
+if _SKLEARN_AVAILABLE:
+    from sklearn import datasets
+
+
+class TestTemplatePreprocess:
+    """Tests ``TemplatePreprocess``."""
+
+    def test_smoke(self):
+        """A simple test that the class can be instantiated."""
+        prep = TemplatePreprocess()
+        assert prep is not None
+
+
+@pytest.mark.skipif(not _SKLEARN_AVAILABLE, reason="sklearn isn't installed.")
+class TestTemplateData:
+    """Tests ``TemplateData``."""
+
+    num_classes: int = 3
+    num_features: int = 4
+
+    def test_smoke(self):
+        """A simple test that the class can be instantiated."""
+        dm = TemplateData()
+        assert dm is not None
+
+    def test_from_numpy(self):
+        """Tests that ``TemplateData`` is properly created when using the ``from_numpy`` method."""
+        data = np.random.rand(10, self.num_features)
+        targets = np.random.randint(0, self.num_classes, (10, ))
+
+        # instantiate the data module
+        dm = TemplateData.from_numpy(
+            train_data=data,
+            train_targets=targets,
+            val_data=data,
+            val_targets=targets,
+            test_data=data,
+            test_targets=targets,
+            batch_size=2,
+            num_workers=0,
+        )
+        assert dm is not None
+        assert dm.train_dataloader() is not None
+        assert dm.val_dataloader() is not None
+        assert dm.test_dataloader() is not None
+
+        # check training data
+        data = next(iter(dm.train_dataloader()))
+        rows, targets = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET]
+        assert rows.shape == (2, self.num_features)
+        assert targets.shape == (2, )
+
+        # check val data
+        data = next(iter(dm.val_dataloader()))
+        rows, targets = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET]
+        assert rows.shape == (2, self.num_features)
+        assert targets.shape == (2, )
+
+        # check test data
+        data = next(iter(dm.test_dataloader()))
+        rows, targets = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET]
+        assert rows.shape == (2, self.num_features)
+        assert targets.shape == (2, )
+
+    def test_from_sklearn(self):
+        """Tests that ``TemplateData`` is properly created when using the ``from_sklearn`` method."""
+        data = datasets.load_iris()
+
+        # instantiate the data module
+        dm = TemplateData.from_sklearn(
+            train_bunch=data,
+            val_bunch=data,
+            test_bunch=data,
+            batch_size=2,
+            num_workers=0,
+        )
+        assert 
dm is not None
+        assert dm.train_dataloader() is not None
+        assert dm.val_dataloader() is not None
+        assert dm.test_dataloader() is not None
+
+        # check training data
+        data = next(iter(dm.train_dataloader()))
+        rows, targets = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET]
+        assert rows.shape == (2, dm.num_features)
+        assert targets.shape == (2, )
+
+        # check val data
+        data = next(iter(dm.val_dataloader()))
+        rows, targets = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET]
+        assert rows.shape == (2, dm.num_features)
+        assert targets.shape == (2, )
+
+        # check test data
+        data = next(iter(dm.test_dataloader()))
+        rows, targets = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET]
+        assert rows.shape == (2, dm.num_features)
+        assert targets.shape == (2, )
diff --git a/tests/template/classification/test_model.py b/tests/template/classification/test_model.py
new file mode 100644
index 0000000000..afe4fee696
--- /dev/null
+++ b/tests/template/classification/test_model.py
@@ -0,0 +1,100 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import numpy as np
+import pytest
+import torch
+
+from flash import Trainer
+from flash.core.data.data_pipeline import DataPipeline
+from flash.core.data.data_source import DefaultDataKeys
+from flash.core.utilities.imports import _SKLEARN_AVAILABLE
+from flash.template import TemplateSKLearnClassifier
+from flash.template.classification.data import TemplatePreprocess
+
+if _SKLEARN_AVAILABLE:
+    from sklearn import datasets
+
+# ======== Mock functions ========
+
+
+class DummyDataset(torch.utils.data.Dataset):
+    """We create one or more ``DummyDataset`` classes to provide random data to the model for testing."""
+
+    num_classes: int = 3
+    num_features: int = 4
+
+    def __getitem__(self, index):
+        return {
+            DefaultDataKeys.INPUT: torch.randn(self.num_features),
+            DefaultDataKeys.TARGET: torch.randint(self.num_classes, (1, )),
+        }
+
+    def __len__(self) -> int:
+        return 10
+
+
+# ==============================
+
+
+@pytest.mark.skipif(not _SKLEARN_AVAILABLE, reason="sklearn isn't installed.")
+def test_smoke():
+    """A simple test that the class can be instantiated."""
+    model = TemplateSKLearnClassifier(num_features=1, num_classes=1)
+    assert model is not None
+
+
+@pytest.mark.skipif(not _SKLEARN_AVAILABLE, reason="sklearn isn't installed.")
+@pytest.mark.parametrize("num_classes", [4, 256])
+@pytest.mark.parametrize("shape", [(1, 3), (2, 128)])
+def test_forward(num_classes, shape):
+    """Tests that a tensor can be given to the model forward and gives the correct output size."""
+    model = TemplateSKLearnClassifier(
+        num_features=shape[1],
+        num_classes=num_classes,
+    )
+    model.eval()
+
+    row = torch.rand(*shape)
+
+    out = model(row)
+    assert out.shape == (shape[0], num_classes)
+
+
+@pytest.mark.skipif(not _SKLEARN_AVAILABLE, reason="sklearn isn't installed.")
+def test_init_train(tmpdir):
+    """Tests that the model can be trained on our ``DummyDataset``."""
+    model = 
TemplateSKLearnClassifier(num_features=DummyDataset.num_features, num_classes=DummyDataset.num_classes)
+    train_dl = torch.utils.data.DataLoader(DummyDataset(), batch_size=4)
+    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
+    trainer.fit(model, train_dl)
+
+
+@pytest.mark.skipif(not _SKLEARN_AVAILABLE, reason="sklearn isn't installed.")
+def test_predict_numpy():
+    """Tests that we can generate predictions from a numpy array."""
+    row = np.random.rand(1, DummyDataset.num_features)
+    model = TemplateSKLearnClassifier(num_features=DummyDataset.num_features, num_classes=DummyDataset.num_classes)
+    data_pipe = DataPipeline(preprocess=TemplatePreprocess())
+    out = model.predict(row, data_pipeline=data_pipe)
+    assert isinstance(out[0], int)
+
+
+@pytest.mark.skipif(not _SKLEARN_AVAILABLE, reason="sklearn isn't installed.")
+def test_predict_bunch():
+    """Tests that we can generate predictions from a scikit-learn ``Bunch``."""
+    bunch = datasets.load_iris()
+    model = TemplateSKLearnClassifier(num_features=DummyDataset.num_features, num_classes=DummyDataset.num_classes)
+    data_pipe = DataPipeline(preprocess=TemplatePreprocess())
+    out = model.predict(bunch, data_source="sklearn", data_pipeline=data_pipe)
+    assert isinstance(out[0], int)

From cc3001ad22af02b176d9f2267f01a1afd5e9a247 Mon Sep 17 00:00:00 2001
From: Ethan Harris
Date: Mon, 17 May 2021 18:31:23 +0100
Subject: [PATCH 14/53] Updates

---
 docs/source/index.rst                       |  6 ++
 docs/source/reference/template.rst          | 75 +++++++++++++++++++
 docs/source/template/data.rst               |  6 +-
 docs/source/template/docs.rst               | 29 ++++++++
 docs/source/template/tests.rst              | 74 ++++++++++++++++++++
 tests/examples/test_scripts.py              | 11 +++
 tests/template/classification/test_data.py  |  3 +-
 tests/template/classification/test_model.py | 14 ++--
 8 files changed, 207 insertions(+), 11 deletions(-)
 create mode 100644 docs/source/reference/template.rst
 create mode 100644 docs/source/template/docs.rst

diff --git a/docs/source/index.rst b/docs/source/index.rst
index 1ce28229d6..2339dff942 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -60,6 +60,12 @@ Lightning Flash
    template/optional
    template/examples
    template/tests
+   template/docs
+
+.. toctree::
+   :hidden:
+
+   reference/template

 Indices and tables
 ==================
diff --git a/docs/source/reference/template.rst b/docs/source/reference/template.rst
new file mode 100644
index 0000000000..7ed1407d69
--- /dev/null
+++ b/docs/source/reference/template.rst
@@ -0,0 +1,75 @@
+
+.. _template:
+
+########
+Template
+########
+
+********
+The task
+********
+
+Here you should add a description of your task. For example:
+Classification is the task of assigning one of a number of classes to each data point.
+The :class:`~flash.template.TemplateSKLearnClassifier` is a :class:`~flash.core.model.Task` for classifying the datasets included with scikit-learn.
+
+------
+
+*********
+Inference
+*********
+
+Here, you should add a short intro to your predict example, and then use ``literalinclude`` to add it.
+
+.. note:: We skip the first 14 lines as they are just the copyright notice.
+
+Our predict example uses a model pre-trained on the iris data.
+
+.. literalinclude:: ../../../flash_examples/predict/template.py
+    :language: python
+    :lines: 14-
+
+For more advanced inference options, see :ref:`predictions`.
+
+------
+
+********
+Training
+********
+
+In this section, we briefly describe the data, and then ``literalinclude`` our finetuning example.
+
+Now we'll train on Fisher's classic iris data.
+It contains 150 records with four features (sepal length, sepal width, petal length, and petal width) in three classes (species of Iris: setosa, virginica and versicolor).
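+
+If you haven't worked with it before, a couple of lines of scikit-learn are enough to get a feel for the data (this snippet is purely illustrative and isn't part of the example below):
+
+.. code-block:: python
+
+    from sklearn import datasets
+
+    data_bunch = datasets.load_iris()
+    print(data_bunch.data.shape)    # (150, 4): the four features for each record
+    print(data_bunch.target.shape)  # (150,): the class index for each record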
+
+Now all we need is to train our task!
+
+.. literalinclude:: ../../../flash_examples/finetuning/template.py
+    :language: python
+    :lines: 14-
+
+------
+
+*************
+API reference
+*************
+
+We usually include the API reference for the :class:`~flash.core.model.Task` and :class:`~flash.core.data.data_module.DataModule`.
+You can optionally add the other classes you've implemented.
+To add the API reference, use the ``autoclass`` directive.
+
+.. _template_classifier:
+
+TemplateSKLearnClassifier
+-------------------------
+
+.. autoclass:: flash.template.TemplateSKLearnClassifier
+    :members:
+    :exclude-members: forward
+
+.. _template_data:
+
+TemplateData
+------------
+
+.. autoclass:: flash.template.TemplateData
diff --git a/docs/source/template/data.rst b/docs/source/template/data.rst
index d989f0b5dd..2ca67a3e52 100644
--- a/docs/source/template/data.rst
+++ b/docs/source/template/data.rst
@@ -7,7 +7,7 @@ The Data
 The first step to contributing a task is to implement the classes we need to load some data.
 Inside ``data.py`` you should implement:

-#. one or more :class:`~flash.core.data.data_source.DataSource` classes
+#. zero or more :class:`~flash.core.data.data_source.DataSource` classes
 #. a :class:`~flash.core.data.process.Preprocess`
 #. a :class:`~flash.core.data.data_module.DataModule`
 #. a :class:`~flash.core.data.callbacks.BaseVisualization` *(optional)*
@@ -17,7 +17,7 @@ DataSource
 ^^^^^^^^^^

 The :class:`~flash.core.data.data_source.DataSource` implementations describe how data from particular sources (like folders, files, tensors, etc.) should be loaded.
-At a minimum you will require one :class:`~flash.core.data.data_source.DataSource` implementation, but if you want to support a few different ways of loading data for your task, the more the merrier!
+If you just want to support :meth:`flash.core.data.data_module.DataModule.from_datasets` you won't need a :class:`~flash.core.data.data_source.DataSource`, but if you want to support a few different ways of loading data for your task, the more the merrier!
 Each :class:`~flash.core.data.data_source.DataSource` has a ``load_data`` method and a ``load_sample`` method.
 The ``load_data`` method accepts some dataset metadata (e.g. a folder name) and produces a sequence or iterable of samples or sample metadata.
 The ``load_sample`` method then takes as input a single element from the output of ``load_data`` and returns a sample.
@@ -69,7 +69,7 @@ A :class:`~flash.core.data.data_source.DataSource` is not the same as a :class:`
 A :class:`torch.utils.data.Dataset` knows about the data, whereas a :class:`~flash.core.data.data_source.DataSource` only knows about how to load the data.
 So it's possible for a single :class:`~flash.core.data.data_source.DataSource` to create more than one :class:`~torch.utils.data.Dataset`.
 It's also fine for the output of the ``load_data`` method to just be a :class:`torch.utils.data.Dataset` instance.
-You don't need to re-write custom datasets, just instantiate them in ``load_data`` similarly to what we did with the ``TemplateSKLearnDataSource``.
+You don't need to re-write custom datasets; either use :meth:`flash.core.data.data_module.DataModule.from_datasets` or just instantiate them in ``load_data``, similarly to what we did with the ``TemplateSKLearnDataSource``.
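+Here's a minimal sketch of the :meth:`~flash.core.data.data_module.DataModule.from_datasets` route (the exact keyword arguments here are our assumption, so check the signature for your version):
+
+.. code-block:: python
+
+    import torch
+
+    from flash.template import TemplateData
+
+    # Any existing map-style dataset can be passed straight in; no custom DataSource is needed.
+    train_dataset = torch.utils.data.TensorDataset(torch.randn(10, 4), torch.randint(3, (10,)))
+    datamodule = TemplateData.from_datasets(train_dataset=train_dataset, batch_size=2)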
For example, the ``load_data`` of the ``VideoClassificationPathsDataSource`` just creates an :class:`~pytorchvideo.data.EncodedVideoDataset` from the given folder.
Here's how it looks (from ``video/classification/data.py``):
diff --git a/docs/source/template/docs.rst b/docs/source/template/docs.rst
new file mode 100644
index 0000000000..1cd7502cda
--- /dev/null
+++ b/docs/source/template/docs.rst
@@ -0,0 +1,29 @@
+.. _contributing_docs:
+
+*********
+The Docs
+*********
+
+The final step is to add some docs.
+For each :class:`~flash.core.model.Task` in Flash, we have a docs page in ``docs/source/reference``.
+You should create a ``.rst`` file there with the following:
+
+- a brief description of the task
+- the predict example
+- the finetuning example
+- any relevant API reference
+
+Here are the contents of ``docs/source/reference/template.rst``, which break down each of these steps:
+
+.. literalinclude:: ../reference/template.rst
+    :language: rest
+
+:ref:`Here's the rendered doc page!