-
Notifications
You must be signed in to change notification settings - Fork 3.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
HenryJia: auto-move data decorator (#1905)
* First attempt at auto-moving data for inference * Correct my copypaste errors * Correct for if device is CPU * Get rid of the WIP code I accidentally added * Add tests * Make tests more foolproof * Make sure we stick with pep8 formatting * Clarify docs a little * Apply suggestions from code review * Get everything working again hopefully * refactor and added hook variant a variant b add test revert rename add changelog docs * move changelog entry to top * Move data transfer to utilities * Add back in warnings for autotransfer * Get rid of the test code I ended up accidentally commiting again * Add docs any changelog * Correct PR number in Changelog * Correct changelog * Update data.py * Update test_cpu.py * make a decorator * type hint * changelog * changelog * remove old function * import * test for decorator * fix test * remove old test * doctest * apply decorator directly * convert doctest to code block * prevent side effects in tests * fix merge * update forward docs * update docs * added docs in section "deployment / prediction" * update changelog Co-authored-by: Hengjian Jia <[email protected]> Co-authored-by: Jirka Borovec <[email protected]> Co-authored-by: William Falcon <[email protected]>
- Loading branch information
1 parent
a5cc4e8
commit 22d9464
Showing
5 changed files
with
94 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
import pytest | ||
import torch | ||
|
||
from tests.base import EvalModelTemplate | ||
from pytorch_lightning.core.decorators import auto_move_data | ||
|
||
|
||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine") | ||
@pytest.mark.parametrize(['src_device', 'dest_device'], [ | ||
pytest.param(torch.device('cpu'), torch.device('cpu')), | ||
pytest.param(torch.device('cpu', 0), torch.device('cuda', 0)), | ||
pytest.param(torch.device('cuda', 0), torch.device('cpu')), | ||
pytest.param(torch.device('cuda', 0), torch.device('cuda', 0)), | ||
]) | ||
def test_auto_move_data(src_device, dest_device): | ||
""" Test that the decorator moves the data to the device the model is on. """ | ||
|
||
class CurrentModel(EvalModelTemplate): | ||
pass | ||
|
||
# apply the decorator | ||
CurrentModel.forward = auto_move_data(CurrentModel.forward) | ||
|
||
model = CurrentModel() | ||
model = model.to(dest_device) | ||
model.prepare_data() | ||
loader = model.train_dataloader() | ||
x, y, = next(iter(loader)) | ||
x = x.flatten(1) | ||
|
||
# test that data on source device gets moved to destination device | ||
x = x.to(src_device) | ||
assert model(x).device == dest_device, "Automoving data to same device as model failed" |