Commit 1c33a4a: "update"
tchaton committed Apr 12, 2021 (1 parent: 5ca0c3e)
Showing 6 changed files with 120 additions and 155 deletions.
201 changes: 87 additions & 114 deletions docs/source/custom_task.rst
@@ -1,11 +1,21 @@
Tutorial: Creating a Custom Task
================================

In this tutorial we will go over the process of creating a custom :class:`~flash.core.model.Task`,
along with a custom :class:`~flash.data.data_module.DataModule`.


The tutorial objective is to create a ``RegressionTask`` to learn to predict if someone has ``diabetes`` or not,
using the diabetes dataset (stored as numpy arrays).
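
Such data can be loaded as plain numpy arrays; a minimal sketch, assuming scikit-learn's diabetes dataset is used:

Example::

    from sklearn import datasets

    # ``x`` holds the input features, ``y`` the regression targets
    x, y = datasets.load_diabetes(return_X_y=True)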

.. note::

Find the complete tutorial example at
`flash_examples/custom_task.py <https://github.com/PyTorchLightning/lightning-flash/blob/revamp_doc/flash_examples/custom_task.py>`_.


1. Imports
----------


.. testcode:: python
@@ -23,12 +33,82 @@ along with a custom data module.
from flash.data.auto_dataset import AutoDataset
from flash.data.process import Postprocess, Preprocess

# set the random seeds.
seed_everything(42)
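
(The first import lines are collapsed in the diff above; judging from the code in this tutorial, they plausibly include ``import flash``, ``import torch``, ``from torch import nn`` and ``from pytorch_lightning import seed_everything``.)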


2. The Task: Linear regression
-------------------------------

Here we create a basic linear regression task by subclassing
:class:`~flash.core.model.Task`. For the majority of tasks, you will likely only need to
override the ``__init__`` and ``forward`` methods.

.. testcode::

class RegressionTask(flash.Task):

def __init__(self, num_inputs, learning_rate=0.001, metrics=None):
# what kind of model do we want?
model = nn.Linear(num_inputs, 1)

# what loss function do we want?
loss_fn = torch.nn.functional.mse_loss

# what optimizer do we want?
optimizer = torch.optim.SGD

super().__init__(
model=model,
loss_fn=loss_fn,
optimizer=optimizer,
metrics=metrics,
learning_rate=learning_rate,
)

def forward(self, x):
# we don't actually need to override this method for this example
return self.model(x)

.. note::

Lightning Flash provides an API to register models within a store.
Check out :ref:`registry`.


Where is the training step?
~~~~~~~~~~~~~~~~~~~~~~~~~~~

Most models can be trained simply by passing the output of ``forward``
to the supplied ``loss_fn``, and then passing the resulting loss to the
supplied ``optimizer``. If you need a more custom configuration, you can
override ``step`` (which is called for training, validation, and
testing) or override ``training_step``, ``validation_step``, and
``test_step`` individually. These methods behave identically to PyTorch
Lightning’s
`methods <https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#methods>`__.

Here is the pseudocode behind the :class:`~flash.core.model.Task` ``step`` method.

Example::

    def step(self, batch: Any, batch_idx: int) -> Any:
        """
        The training/validation/test step. Override for custom behavior.
        """
        x, y = batch
        y_hat = self(x)
        # compute the loss, logs and metrics and gather them in an output dictionary
        output = {"y_hat": y_hat, "loss": self.loss_fn(y_hat, y)}
        return output
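
For example, here is a minimal sketch of overriding ``training_step`` alone (the logged metric name is only an illustration):

Example::

    class RegressionTask(flash.Task):

        ...

        def training_step(self, batch, batch_idx):
            x, y = batch
            y_hat = self(x)
            loss = torch.nn.functional.mse_loss(y_hat, y)
            # log the training loss to the progress bar, as in PyTorch Lightning
            self.log("train_mse", loss, prog_bar=True)
            return loss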


3.a The DataModule API
----------------------

First, let's design the user-facing API.
We are going to create a ``NumpyDataModule`` class subclassing :class:`~flash.data.data_module.DataModule`.
This ``NumpyDataModule`` will provide a ``from_xy_dataset`` helper ``classmethod`` to instantiate a
:class:`~flash.data.data_module.DataModule` from x, y numpy arrays.

Example::

@@ -83,7 +163,7 @@ Example::
Make sure to instantiate your :class:`~flash.data.data_module.DataModule` with this helper if you rely on :class:`~flash.data.process.Preprocess`
objects.
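
With this helper in place, creating the datamodule is a one-liner; a sketch, assuming ``x`` and ``y`` are the numpy arrays loaded in the introduction:

Example::

    data = NumpyDataModule.from_xy_dataset(x, y)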

3.b The Preprocess API
----------------------

.. note::
@@ -130,72 +210,6 @@ Example::
return torch.from_numpy(sample).float()
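
Most of this example is collapsed in the diff above; only its last line is visible. A minimal sketch of what such a preprocess could look like (the class name and the choice of hooks are assumptions):

Example::

    class NumpyPreprocess(Preprocess):

        def load_data(self, data: Any, dataset: AutoDataset) -> Any:
            # ``data`` holds the numpy samples collected by the DataModule
            return data

        def to_tensor_transform(self, sample: Any) -> torch.Tensor:
            # convert each numpy sample into a float tensor for the model
            return torch.from_numpy(sample).float()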



4. Fitting
----------
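
Part of this section is collapsed in the diff below; the hidden lines presumably build the task and datamodule. A sketch under the assumptions used so far:

Example::

    # ``data`` is the datamodule from section 3.a; ``num_inputs`` matches
    # the number of feature columns in ``x``
    model = RegressionTask(num_inputs=x.shape[1])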

@@ -214,6 +228,7 @@ supplying the task itself, and the associated data:
trainer = flash.Trainer(max_epochs=1000)
trainer.fit(model, data)

5. Predicting
-------------

@@ -233,45 +248,3 @@ few examples from the test set of our data:
predictions = model.predict(predict_data)
print(predictions)
#out: [tensor([14.7190]), tensor([14.7100]), tensor([14.7288]), tensor([14.6685]), tensor([14.6687])]

6. Customize Postprocess
------------------------

To customize the postprocessing of this task, you can create a :class:`~flash.data.process.Postprocess` object and assign it to your model as follows:

.. code:: python

    class CustomPostprocess(Postprocess):

        THRESHOLD = 14.72

        def predict_per_sample_transform(self, pred: Any) -> Any:
            if pred > self.THRESHOLD:

                def send_slack_message(pred):
                    print(f"This prediction: {pred} is above the threshold: {self.THRESHOLD}")

                send_slack_message(pred)
            return pred


    class RegressionTask(flash.Task):

        # ``postprocess_cls`` is a special attribute name used internally
        # to instantiate your Postprocess.
        postprocess_cls = CustomPostprocess

        ...
And when running predict one more time:

.. code:: python

    predict_data = ...
    predictions = model.predict(predict_data)
    # out: This prediction: tensor([14.7288]) is above the threshold: 14.72
    print(predictions)
    # out: [tensor([14.7190]), tensor([14.7100]), tensor([14.7288]), tensor([14.6685]), tensor([14.6687])]
4 changes: 2 additions & 2 deletions docs/source/general/data.rst
@@ -23,8 +23,8 @@ Here are common terms you need to be familiar with:
- The :class:`~flash.data.data_pipeline.DataPipeline` is Flash's internal object to manage the :class:`~flash.data.process.Preprocess` and :class:`~flash.data.process.Postprocess` objects.
* - :class:`~flash.data.process.Preprocess`
- The :class:`~flash.data.process.Preprocess` provides a simple hook-based API to encapsulate your pre-processing logic.
The :class:`~flash.data.process.Preprocess` provides multiple hooks such as :meth:`~flash.data.process.Preprocess.load_data`
and :meth:`~flash.data.process.Preprocess.load_sample`, which are used to replace the logic of a traditional `Dataset`.
The Flash DataPipeline contains a system to call the right hooks when needed.
The :class:`~flash.data.process.Preprocess` hooks cover everything from data loading to model forwarding.
* - :class:`~flash.data.process.Postprocess`
11 changes: 9 additions & 2 deletions docs/source/general/registry.rst
@@ -8,8 +8,11 @@ Registry
Available Registries
********************

Registries are Flash's internal key-value databases used to store mappings between names and functions.

In simple words, they are just advanced dictionaries storing a function for each key string.

Registries help organize code and make the functions accessible all across the ``Flash`` codebase.
Each Flash ``Task`` can have several registries as static attributes.
They enable you to quickly experiment with your own backbone functions or use our long list of available backbones.

@@ -18,6 +21,10 @@ Example::
from flash.vision import ImageClassifier
from flash.core.registry import FlashRegistry

class MyImageClassifier(ImageClassifier):

backbones = FlashRegistry("backbones")

@MyImageClassifier.backbones(name="username/my_backbone")
def fn():
# Create backbone and backbone output dimension (`num_features`)
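
The rest of this example is collapsed in the diff. A sketch of the complete registration pattern, and of looking the backbone up afterwards (the function body and the ``get`` call are assumptions):

Example::

    from torch import nn

    @MyImageClassifier.backbones(name="username/my_backbone")
    def fn():
        # Create backbone and backbone output dimension (``num_features``)
        backbone = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 128))
        num_features = 128
        return backbone, num_features

    # the function is now accessible from the registry under its key
    backbone, num_features = MyImageClassifier.backbones.get("username/my_backbone")()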
2 changes: 1 addition & 1 deletion flash/data/base_viz.py
@@ -89,7 +89,7 @@ def show(self, batch: Dict[str, Any], running_stage: RunningStage):
.. note::
As the :class:`~flash.data.process.Preprocess` hooks are injected within
the threaded workers of the DataLoader,
the data won't be accessible when using ``num_workers > 0``.
"""
14 changes: 7 additions & 7 deletions flash/data/callback.py
@@ -15,25 +15,25 @@ def on_load_sample(self, sample: Any, running_stage: RunningStage) -> None:
"""Called once a sample has been loaded using ``load_sample``."""

def on_pre_tensor_transform(self, sample: Any, running_stage: RunningStage) -> None:
"""Called once ``pre_tensor_transform`` have been applied to a sample."""
"""Called once ``pre_tensor_transform`` has been applied to a sample."""

def on_to_tensor_transform(self, sample: Any, running_stage: RunningStage) -> None:
"""Called once ``to_tensor_transform`` have been applied to a sample."""
"""Called once ``to_tensor_transform`` has been applied to a sample."""

def on_post_tensor_transform(self, sample: Tensor, running_stage: RunningStage) -> None:
"""Called once ``post_tensor_transform`` have been applied to a sample."""
"""Called once ``post_tensor_transform`` has been applied to a sample."""

def on_per_batch_transform(self, batch: Any, running_stage: RunningStage) -> None:
"""Called once ``per_batch_transform`` have been applied to a batch."""
"""Called once ``per_batch_transform`` has been applied to a batch."""

def on_collate(self, batch: Sequence, running_stage: RunningStage) -> None:
"""Called once ``collate`` have been applied to a sequence of samples."""
"""Called once ``collate`` has been applied to a sequence of samples."""

def on_per_sample_transform_on_device(self, sample: Any, running_stage: RunningStage) -> None:
"""Called once ``per_sample_transform_on_device`` have been applied to a sample."""
"""Called once ``per_sample_transform_on_device`` has been applied to a sample."""

def on_per_batch_transform_on_device(self, batch: Any, running_stage: RunningStage) -> None:
"""Called once ``per_batch_transform_on_device`` have been applied to a sample."""
"""Called once ``per_batch_transform_on_device`` has been applied to a sample."""


class ControlFlow(FlashCallback):
