Skip to content

Commit

Permalink
HenryJia: auto-move data decorator (#1905)
Browse files Browse the repository at this point in the history
* First attempt at auto-moving data for inference

* Correct my copypaste errors

* Correct for if device is CPU

* Get rid of the WIP code I accidentally added

* Add tests

* Make tests more foolproof

* Make sure we stick with pep8 formatting

* Clarify docs a little

* Apply suggestions from code review

* Get everything working again hopefully

* refactor and added hook


variant a


variant b


add test


revert rename


add changelog


docs

* move changelog entry to top

* Move data transfer to utilities

* Add back in warnings for autotransfer

* Get rid of the test code I ended up accidentally commiting again

* Add docs any changelog

* Correct PR number in Changelog

* Correct changelog

* Update data.py

* Update test_cpu.py

* make a decorator

* type hint

* changelog

* changelog

* remove old function

* import

* test for decorator

* fix test

* remove old test

* doctest

* apply decorator directly

* convert doctest to code block

* prevent side effects in tests

* fix merge

* update forward docs

* update docs

* added docs in section "deployment / prediction"

* update changelog

Co-authored-by: Hengjian Jia <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: William Falcon <[email protected]>
  • Loading branch information
4 people authored Jun 15, 2020
1 parent a5cc4e8 commit 22d9464
Show file tree
Hide file tree
Showing 5 changed files with 94 additions and 0 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added back the slow spawn ddp implementation as `ddp_spawn` ([#2115](https://github.com/PyTorchLightning/pytorch-lightning/pull/2115))
- Added loading checkpoints from URLs ([#1667](https://github.com/PyTorchLightning/pytorch-lightning/issues/1667))
- Added a callback method `on_keyboard_interrupt` for handling KeyboardInterrupt events during training ([#2134](https://github.com/PyTorchLightning/pytorch-lightning/pull/2134))
- Added a decorator `auto_move_data` that moves data to the correct device when using the LightningModule for inference ([#1905](https://github.com/PyTorchLightning/pytorch-lightning/pull/1905))

### Changed

Expand Down
52 changes: 52 additions & 0 deletions pytorch_lightning/core/decorators.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
from functools import wraps
from typing import Callable

import torch

from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.utilities import rank_zero_warn


Expand All @@ -12,3 +18,49 @@ def data_loader(fn):
def inner_fx(self):
return fn(self)
return inner_fx


def auto_move_data(fn: Callable) -> Callable:
"""
Decorator for :class:`~pytorch_lightning.core.lightning.LightningModule` methods for which
input arguments should be moved automatically to the correct device.
It as no effect if applied to a method of an object that is not an instance of
:class:`~pytorch_lightning.core.lightning.LightningModule` and is typically applied to ``__call__``
or ``forward``.
Args:
fn: A LightningModule method for which the arguments should be moved to the device
the parameters are on.
Example:
.. code-block:: python
# directly in the source code
class LitModel(LightningModule):
@auto_move_data
def forward(self, x):
return x
# or outside
LitModel.forward = auto_move_data(LitModel.forward)
model = LitModel()
model = model.to('cuda')
model(torch.zeros(1, 3))
# input gets moved to device
# tensor([[0., 0., 0.]], device='cuda:0')
"""
@wraps(fn)
def auto_transfer_args(self, *args, **kwargs):
if not isinstance(self, LightningModule):
return fn(self, *args, **kwargs)

args = self.transfer_batch_to_device(args, self.device)
kwargs = self.transfer_batch_to_device(kwargs, self.device)
return fn(self, *args, **kwargs)

return auto_transfer_args
3 changes: 3 additions & 0 deletions pytorch_lightning/core/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,9 @@ def forward(self, *args, **kwargs):
This makes it easy to write a complex system for training with the outputs
you'd want in a prediction setting.
You may also find the :func:`~pytorch_lightning.core.decorators.auto_move_data` decorator useful
when using the module outside Lightning in a production setting.
Args:
*args: Whatever you decide to pass into the forward method.
**kwargs: Keyword arguments are also possible.
Expand Down
5 changes: 5 additions & 0 deletions pytorch_lightning/trainer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,11 @@ def forward(self, x):
out = pretrained_model(x)
api_write({'response': out}
You may wish to run the model on a variety of devices. Instead of moving the data
manually to the correct device, decorate the forward method (or any other method you use for inference)
with :func:`~pytorch_lightning.core.decorators.auto_move_data` and Lightning will take care of the rest.
------------
Reproducibility
Expand Down
33 changes: 33 additions & 0 deletions tests/core/test_decorators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import pytest
import torch

from tests.base import EvalModelTemplate
from pytorch_lightning.core.decorators import auto_move_data


@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
@pytest.mark.parametrize(['src_device', 'dest_device'], [
pytest.param(torch.device('cpu'), torch.device('cpu')),
pytest.param(torch.device('cpu', 0), torch.device('cuda', 0)),
pytest.param(torch.device('cuda', 0), torch.device('cpu')),
pytest.param(torch.device('cuda', 0), torch.device('cuda', 0)),
])
def test_auto_move_data(src_device, dest_device):
""" Test that the decorator moves the data to the device the model is on. """

class CurrentModel(EvalModelTemplate):
pass

# apply the decorator
CurrentModel.forward = auto_move_data(CurrentModel.forward)

model = CurrentModel()
model = model.to(dest_device)
model.prepare_data()
loader = model.train_dataloader()
x, y, = next(iter(loader))
x = x.flatten(1)

# test that data on source device gets moved to destination device
x = x.to(src_device)
assert model(x).device == dest_device, "Automoving data to same device as model failed"

0 comments on commit 22d9464

Please sign in to comment.