From 1c33a4ac123c1f6b4900cf7843798ea4175e61d9 Mon Sep 17 00:00:00 2001
From: tchaton 
Date: Mon, 12 Apr 2021 21:03:41 +0100
Subject: [PATCH] update

---
 docs/source/custom_task.rst      | 201 +++++++++++++------------------
 docs/source/general/data.rst     |   4 +-
 docs/source/general/registry.rst |  11 +-
 flash/data/base_viz.py           |   2 +-
 flash/data/callback.py           |  14 +--
 flash_examples/custom_task.py    |  43 +++----
 6 files changed, 120 insertions(+), 155 deletions(-)

diff --git a/docs/source/custom_task.rst b/docs/source/custom_task.rst
index 7baaf53626..eeb3ef7a1d 100644
--- a/docs/source/custom_task.rst
+++ b/docs/source/custom_task.rst
@@ -1,11 +1,21 @@
 Tutorial: Creating a Custom Task
 ================================
 
-In this tutorial we will go over the process of creating a custom task,
-along with a custom data module.
+In this tutorial we will go over the process of creating a custom :class:`~flash.core.model.Task`,
+along with a custom :class:`~flash.data.data_module.DataModule`.
+
+
+The objective of this tutorial is to create a ``RegressionTask`` that learns to predict whether someone has ``diabetes`` or not.
+The diabetes data is loaded from scikit-learn and stored as numpy arrays.
+
+.. note::
+
+    Find the complete tutorial example at
+    `flash_examples/custom_task.py `_.
+
 
 1. Imports
 ------------
+----------
 
 .. testcode:: python
 
     from typing import Any, List, Tuple
 
     import numpy as np
     import torch
     from pytorch_lightning import seed_everything
     from sklearn import datasets
     from sklearn.model_selection import train_test_split
     from torch import nn
 
     import flash
     from flash.data.auto_dataset import AutoDataset
     from flash.data.process import Postprocess, Preprocess
 
+    # set the random seeds.
     seed_everything(42)
 
-2.a The DataModule API
+
+2. The Task: Linear regression
+-------------------------------
+
+Here we create a basic linear regression task by subclassing
+:class:`~flash.core.model.Task`. For the majority of tasks, you will likely only need to
+override the ``__init__`` and ``forward`` methods.
+
+.. testcode::
+
+    class RegressionTask(flash.Task):
+
+        def __init__(self, num_inputs, learning_rate=0.001, metrics=None):
+            # what kind of model do we want?
+            model = nn.Linear(num_inputs, 1)
+
+            # what loss function do we want?
+            loss_fn = torch.nn.functional.mse_loss
+
+            # what optimizer do we want?
+            optimizer = torch.optim.SGD
+
+            super().__init__(
+                model=model,
+                loss_fn=loss_fn,
+                optimizer=optimizer,
+                metrics=metrics,
+                learning_rate=learning_rate,
+            )
+
+        def forward(self, x):
+            # we don't actually need to override this method for this example
+            return self.model(x)
+
+.. note::
+
+    Lightning Flash provides an API to register models within a store.
+    Check out :ref:`registry`.
+
+
+Where is the training step?
+~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+Most models can be trained simply by passing the output of ``forward``
+to the supplied ``loss_fn``, and then passing the resulting loss to the
+supplied ``optimizer``. If you need a more custom configuration, you can
+override ``step`` (which is called for training, validation, and
+testing) or override ``training_step``, ``validation_step``, and
+``test_step`` individually. These methods behave identically to PyTorch
+Lightning’s
+`methods `__.
+
+Here is the pseudo code behind the :class:`~flash.core.model.Task` ``step`` method.
+
+Example::
+
+    def step(self, batch: Any, batch_idx: int) -> Any:
+        """
+        The training/validation/test step. Override for custom behavior.
+        """
+        x, y = batch
+        y_hat = self(x)
+        # compute the loss, logs and metrics, and gather them into an output dictionary
+        output = {"y_hat": y_hat, "loss": ..., "logs": ...}
+        return output
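+
+For example, here is a sketch of overriding ``training_step`` on the ``RegressionTask``
+above to log the training loss on every step. ``self.log`` is the standard Lightning
+logging hook, and the subclass name is only for illustration::
+
+    class LoggedRegressionTask(RegressionTask):
+
+        def training_step(self, batch: Any, batch_idx: int) -> torch.Tensor:
+            x, y = batch
+            # same loss as the one supplied to the base RegressionTask
+            loss = torch.nn.functional.mse_loss(self(x), y)
+            self.log("train_loss", loss)
+            return loss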
+
+
+3.a The DataModule API
 ----------------------
 
-First, let's design the user-facing API. The ``NumpyDataModule`` will provide a ``from_xy_dataset`` helper ``classmethod``.
+First, let's design the user-facing API.
+We are going to create a ``NumpyDataModule`` class subclassing :class:`~flash.data.data_module.DataModule`.
+This ``NumpyDataModule`` will provide a ``from_xy_dataset`` helper ``classmethod`` to instantiate
+a :class:`~flash.data.data_module.DataModule` from ``x``, ``y`` numpy arrays.
 
 Example::
 
@@ -83,7 +163,7 @@ Example::
 
 Make sure to instantiate your :class:`~flash.data.data_module.DataModule` with this helper if you rely on :class:`~flash.data.process.Preprocess` objects.
 
-2.b The Preprocess API
+3.b The Preprocess API
 ----------------------
 
 .. note::
 
@@ -130,72 +210,6 @@ Example::
 
         return torch.from_numpy(sample).float()
 
-
-3. The Task: Linear regression
--------------------------------
-
-Here we create a basic linear regression task by subclassing
-:class:`~flash.core.model.Task`. For the majority of tasks, you will likely only need to
-override the ``__init__`` and ``forward`` methods.
-
-.. testcode::
-
-    class RegressionTask(flash.Task):
-
-        def __init__(self, num_inputs, learning_rate=0.001, metrics=None):
-            # what kind of model do we want?
-            model = nn.Linear(num_inputs, 1)
-
-            # what loss function do we want?
-            loss_fn = torch.nn.functional.mse_loss
-
-            # what optimizer to do we want?
-            optimizer = torch.optim.SGD
-
-            super().__init__(
-                model=model,
-                loss_fn=loss_fn,
-                optimizer=optimizer,
-                metrics=metrics,
-                learning_rate=learning_rate,
-            )
-
-        def forward(self, x):
-            # we don't actually need to override this method for this example
-            return self.model(x)
-
-.. note::
-
-    Lightning Flash provides an API to register models within a store.
-    Check out :ref:`registry`.
-
-
-Where is the training step?
-~~~~~~~~~~~~~~~~~~~~~~~~~~~
-
-Most models can be trained simply by passing the output of ``forward``
-to the supplied ``loss_fn``, and then passing the resulting loss to the
-supplied ``optimizer``. If you need a more custom configuration, you can
-override ``step`` (which is called for training, validation, and
-testing) or override ``training_step``, ``validation_step``, and
-``test_step`` individually. These methods behave identically to PyTorch
-Lightning’s
-`methods `__.
-
-Here is the pseudo code behind :class:`~flash.core.model.Task` step.
-
-Example::
-
-    def step(self, batch: Any, batch_idx: int) -> Any:
-        """
-        The training/validation/test step. Override for custom behavior.
-        """
-        x, y = batch
-        y_hat = self(x)
-        # compute the logs, loss and metrics as an output dictionary
-        return output
-
-
 4. Fitting
 ----------
 
@@ -214,6 +228,7 @@ supplying the task itself, and the associated data:
 
     trainer = flash.Trainer(max_epochs=1000)
    trainer.fit(model, data)
 
+
 5. Predicting
 -------------
 
@@ -233,45 +248,3 @@ few examples from the test set of our data:
 
     predictions = model.predict(predict_data)
     print(predictions)
     #out: [tensor([14.7190]), tensor([14.7100]), tensor([14.7288]), tensor([14.6685]), tensor([14.6687])]
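+
+Since ``RegressionTask`` is a ``LightningModule``, the usual Lightning checkpointing
+machinery also applies. A minimal sketch (the checkpoint path is just an example,
+and the diabetes data has 10 input features)::
+
+    trainer.save_checkpoint("regression_task.ckpt")
+    model = RegressionTask.load_from_checkpoint("regression_task.ckpt", num_inputs=10)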
-
-
-6. Customize PostProcess
-------------------------
-
-To customize the postprocessing of this task, you can create a :class:`~flash.data.process.Postprocess` object and assign it to your model as follows:
-
-.. code:: python
-
-    class CustomPostprocess(Postprocess):
-
-        THRESHOLD = 14.72
-
-        def predict_per_sample_transform(self, pred: Any) -> Any:
-            if pred > self.THRESHOLD:
-
-                def send_slack_message(pred):
-                    print(f"This prediction: {pred} is above the threshold: {self.THRESHOLD}")
-
-                send_slack_message(pred)
-            return pred
-
-
-    class RegressionTask(flash.Task):
-
-        # ``postprocess_cls`` is a special attribute name used internally
-        # to instantiate your Postprocess.
-        postprocess_cls = CustomPostprocess
-
-        ...
-
-And when running predict one more time.
-
-.. code:: python
-
-    predict_data = ...
-
-    predictions = model.predict(predict_data)
-    # out: This prediction: tensor([14.7288]) is above the threshold: 14.72
-
-    print(predictions)
-    # out: [tensor([14.7190]), tensor([14.7100]), tensor([14.7288]), tensor([14.6685]), tensor([14.6687])]
diff --git a/docs/source/general/data.rst b/docs/source/general/data.rst
index ece79b78e6..da21d3c09a 100644
--- a/docs/source/general/data.rst
+++ b/docs/source/general/data.rst
@@ -23,8 +23,8 @@ Here are common terms you need to be familiar with:
       - The :class:`~flash.data.data_pipeline.DataPipeline` is a Flash internal object used to manage :class:`~flash.data.process.Preprocess` and :class:`~flash.data.process.Postprocess` objects.
    * - :class:`~flash.data.process.Preprocess`
      - The :class:`~flash.data.process.Preprocess` provides a simple hook-based API to encapsulate your pre-processing logic.
-        The :class:`~flash.data.process.Preprocess` provides multiple hooks and :meth:`~flash.data.process.Preprocess.load_data`
-        and :meth:`~flash.data.process.Preprocess.load_sample` functions are used to replace a traditional `Dataset`.
+        The :class:`~flash.data.process.Preprocess` provides multiple hooks such as :meth:`~flash.data.process.Preprocess.load_data`
+        and :meth:`~flash.data.process.Preprocess.load_sample`, which are used to replace the logic of a traditional `Dataset`.
         Flash DataPipeline contains a system to call the right hooks when needed.
         The :class:`~flash.data.process.Preprocess` hooks cover everything from data loading to model forwarding.
    * - :class:`~flash.data.process.Postprocess`
diff --git a/docs/source/general/registry.rst b/docs/source/general/registry.rst
index bd683a2061..ac499d5f58 100644
--- a/docs/source/general/registry.rst
+++ b/docs/source/general/registry.rst
@@ -8,8 +8,11 @@ Registry
 Available Registries
 ********************
 
-Registries are Flash internal key-value database to store mapping between a name and a function.
-It helps organize code and make the functions accessible all across the ``Flash`` codebase directly from the key.
+Registries are Flash's internal key-value databases that store a mapping between a name and a function.
+
+In simple words, they are just advanced dictionaries that store functions under string keys.
+
+Registries help organize code and make the functions accessible all across the ``Flash`` codebase.
 
 Each Flash ``Task`` can have several registries as static attributes.
 They enable you to quickly experiment with your backbone functions or use our long list of available backbones.
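+
+In practice a registry behaves like a dictionary keyed by strings: you register a
+function under a name and fetch it back with ``get``. A minimal sketch (the registry
+and key names below are arbitrary)::
+
+    from flash.core.registry import FlashRegistry
+
+    backbones = FlashRegistry("backbones")
+
+    @backbones(name="my_backbone")
+    def my_backbone():
+        ...
+
+    backbone_fn = backbones.get("my_backbone")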
@@ -18,6 +21,10 @@ Example::
 
     from flash.vision import ImageClassifier
     from flash.core.registry import FlashRegistry
 
+    class MyImageClassifier(ImageClassifier):
+
+        backbones = FlashRegistry("backbones")
+
     @MyImageClassifier.backbones(name="username/my_backbone")
     def fn():
         # Create backbone and backbone output dimension (`num_features`)
diff --git a/flash/data/base_viz.py b/flash/data/base_viz.py
index 6cdf6a03dc..7efe975e43 100644
--- a/flash/data/base_viz.py
+++ b/flash/data/base_viz.py
@@ -89,7 +89,7 @@ def show(self, batch: Dict[str, Any], running_stage: RunningStage):
 
         .. note::
 
             As the :class:`~flash.data.process.Preprocess` hooks are injected within
-            the threaded workers for the DataLoader,
+            the threaded workers of the DataLoader,
             the data won't be accessible when using ``num_workers > 0``.
 
     """
diff --git a/flash/data/callback.py b/flash/data/callback.py
index 92ab35b981..df8ad91600 100644
--- a/flash/data/callback.py
+++ b/flash/data/callback.py
@@ -15,25 +15,25 @@ def on_load_sample(self, sample: Any, running_stage: RunningStage) -> None:
         """Called once a sample has been loaded using ``load_sample``."""
 
     def on_pre_tensor_transform(self, sample: Any, running_stage: RunningStage) -> None:
-        """Called once ``pre_tensor_transform`` have been applied to a sample."""
+        """Called once ``pre_tensor_transform`` has been applied to a sample."""
 
     def on_to_tensor_transform(self, sample: Any, running_stage: RunningStage) -> None:
-        """Called once ``to_tensor_transform`` have been applied to a sample."""
+        """Called once ``to_tensor_transform`` has been applied to a sample."""
 
     def on_post_tensor_transform(self, sample: Tensor, running_stage: RunningStage) -> None:
-        """Called once ``post_tensor_transform`` have been applied to a sample."""
+        """Called once ``post_tensor_transform`` has been applied to a sample."""
 
     def on_per_batch_transform(self, batch: Any, running_stage: RunningStage) -> None:
-        """Called once ``per_batch_transform`` have been applied to a batch."""
+        """Called once ``per_batch_transform`` has been applied to a batch."""
 
     def on_collate(self, batch: Sequence, running_stage: RunningStage) -> None:
-        """Called once ``collate`` have been applied to a sequence of samples."""
+        """Called once ``collate`` has been applied to a sequence of samples."""
 
     def on_per_sample_transform_on_device(self, sample: Any, running_stage: RunningStage) -> None:
-        """Called once ``per_sample_transform_on_device`` have been applied to a sample."""
+        """Called once ``per_sample_transform_on_device`` has been applied to a sample."""
 
     def on_per_batch_transform_on_device(self, batch: Any, running_stage: RunningStage) -> None:
-        """Called once ``per_batch_transform_on_device`` have been applied to a sample."""
+        """Called once ``per_batch_transform_on_device`` has been applied to a batch."""
 
 
 class ControlFlow(FlashCallback):
diff --git a/flash_examples/custom_task.py b/flash_examples/custom_task.py
index 82f829515c..db337f0471 100644
--- a/flash_examples/custom_task.py
+++ b/flash_examples/custom_task.py
@@ -5,7 +5,7 @@
 from pytorch_lightning import seed_everything
 from sklearn import datasets
 from sklearn.model_selection import train_test_split
-from torch import nn
+from torch import nn, Tensor
 
 import flash
 from flash.data.auto_dataset import AutoDataset
@@ -13,24 +13,10 @@
 
 seed_everything(42)
 
+ND = np.ndarray
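+# ``ND`` is a short alias used in the type hints below: the preprocess and the
+# data module pass plain numpy arrays around until they are turned into tensors.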
 
-class CustomPostprocess(Postprocess):
-
-    THRESHOLD = 14.72
-
-    def predict_per_sample_transform(self, pred: Any) -> Any:
-        if pred > self.THRESHOLD:
-
-            def send_slack_message(pred):
-                print(f"This prediction: {pred} is above the threshold: {self.THRESHOLD}")
-
-            send_slack_message(pred)
-        return pred
-
-
-class LinearRegression(flash.Task):
-
-    postprocess_cls = CustomPostprocess
+class RegressionTask(flash.Task):
 
     def __init__(self, num_inputs, learning_rate=0.001, metrics=None):
         # what kind of model do we want?
@@ -57,32 +43,30 @@ def forward(self, x):
 
 class NumpyPreprocess(Preprocess):
 
-    def load_data(self, data: Tuple[np.ndarray, np.ndarray], dataset: AutoDataset) -> List[Tuple[np.ndarray, float]]:
+    def load_data(self, data: Tuple[ND, ND], dataset: AutoDataset) -> List[Tuple[ND, float]]:
         if self.training:
             dataset.num_inputs = data[0].shape[1]
         return [(x, y) for x, y in zip(*data)]
 
-    def to_tensor_transform(self, sample: Any) -> Tuple[torch.Tensor, torch.Tensor]:
+    def to_tensor_transform(self, sample: Any) -> Tuple[Tensor, Tensor]:
         x, y = sample
         x = torch.from_numpy(x).float()
         y = torch.tensor(y, dtype=torch.float)
         return x, y
 
-    def predict_load_data(self, data: np.ndarray) -> np.ndarray:
+    def predict_load_data(self, data: ND) -> ND:
         return data
 
-    def predict_to_tensor_transform(self, sample: np.ndarray) -> np.ndarray:
+    def predict_to_tensor_transform(self, sample: ND) -> Tensor:
         return torch.from_numpy(sample).float()
 
 
-class SklearnDataModule(flash.DataModule):
-
-    preprocess_cls = NumpyPreprocess
+class NumpyDataModule(flash.DataModule):
 
     @classmethod
-    def from_dataset(cls, x: np.ndarray, y: np.ndarray, batch_size: int = 64, num_workers: int = 0):
+    def from_dataset(cls, x: ND, y: ND, preprocess: Preprocess, batch_size: int = 64, num_workers: int = 0):
 
-        preprocess = cls.preprocess_cls()
 
         x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=.20, random_state=0)
 
@@ -93,12 +77,13 @@ def from_dataset(cls, x: np.ndarray, y: np.ndarray, batch_size: int = 64, num_wo
         batch_size=batch_size,
         num_workers=num_workers
     )
-    dm.num_inputs = dm._train_ds.num_inputs
+    dm.num_inputs = dm.train_dataset.num_inputs
     return dm
 
 
-datamodule = SklearnDataModule.from_dataset(*datasets.load_diabetes(return_X_y=True))
-model = LinearRegression(num_inputs=datamodule.num_inputs)
+x, y = datasets.load_diabetes(return_X_y=True)
+datamodule = NumpyDataModule.from_dataset(x, y, NumpyPreprocess())
+model = RegressionTask(num_inputs=datamodule.num_inputs)
 
 trainer = flash.Trainer(max_epochs=10, progress_bar_refresh_rate=20)
 trainer.fit(model, datamodule=datamodule)
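+
+# A minimal sketch of prediction after fitting (mirrors "5. Predicting" in the
+# tutorial); the first five rows of ``x`` stand in for new data here.
+predict_data = x[:5]
+predictions = model.predict(predict_data)
+print(predictions)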