
Clean up dataloader logic #926

Merged · 80 commits · Feb 25, 2020
Commits (80)
All 80 commits are by williamFalcon (Feb 24–25, 2020).

15bfd14  added get dataloaders directly using a getter (Feb 24)
6c7a37e  deleted decorator (Feb 24)
38e6991  added prepare_data hook (Feb 24)
37be7e7  refactored dataloader init (Feb 24)
db1115e  refactored dataloader init (Feb 24)
dbe3fc0  added dataloader reset flag and main loop (Feb 24)
bb642b8  added dataloader reset flag and main loop (Feb 24)
6d646fd  added dataloader reset flag and main loop (Feb 24)
412899c  made changes (Feb 24)
252713f  made changes (Feb 24)
de77ece  made changes (Feb 24)
f723574  made changes (Feb 24)
9009c22  made changes (Feb 24)
74984e3  made changes (Feb 24)
838879b  made changes (Feb 24)
14f3a1d  made changes (Feb 24)
93f6b19  made changes (Feb 24)
b021212  made changes (Feb 24)
45db9be  made changes (Feb 24)
20b5c62  made changes (Feb 24)
189bbb1  made changes (Feb 24)
0a45b1a  made changes (Feb 24)
767ad23  made changes (Feb 24)
56c0654  made changes (Feb 24)
783b5c7  made changes (Feb 24)
8183e82  made changes (Feb 24)
2a2a6ef  made changes (Feb 24)
8347a18  made changes (Feb 24)
176d62d  made changes (Feb 24)
4e3fb96  made changes (Feb 24)
51cc57f  made changes (Feb 24)
c55bb0d  made changes (Feb 24)
6fe933b  made changes (Feb 24)
d43c9e7  made changes (Feb 24)
05ab2db  made changes (Feb 24)
d165b46  made changes (Feb 24)
c5535e8  made changes (Feb 24)
1e56281  made changes (Feb 24)
7623a27  made changes (Feb 24)
36697f3  made changes (Feb 24)
803e72d  made changes (Feb 24)
a8f3e19  made changes (Feb 24)
53598e2  made changes (Feb 24)
cddbac8  made changes (Feb 24)
e42b1b7  made changes (Feb 24)
b83a7d7  made changes (Feb 24)
d117727  made changes (Feb 24)
6e57368  made changes (Feb 24)
de64175  made changes (Feb 24)
55d302d  made changes (Feb 24)
df70d2e  made changes (Feb 24)
3635e61  made changes (Feb 25)
f7a6382  made changes (Feb 25)
abd2126  made changes (Feb 25)
90eef5e  fixed bad loaders (Feb 25)
d047618  fixed bad loaders (Feb 25)
eb9a380  fixed bad loaders (Feb 25)
cb4f761  fixed bad loaders (Feb 25)
cb8e977  fixed bad loaders (Feb 25)
82ec6ce  fixed bad loaders (Feb 25)
617dd32  fixed bad loaders (Feb 25)
97cf4c0  fixed bad loaders (Feb 25)
1671fbc  fixed bad loaders (Feb 25)
08eeb48  fixed error in .fit with loaders (Feb 25)
d9cfcdb  fixed error in .fit with loaders (Feb 25)
d2db8f2  fixed error in .fit with loaders (Feb 25)
35a8880  fixed error in .fit with loaders (Feb 25)
cec5931  fixed error in .fit with loaders (Feb 25)
6c412ed  fixed error in .fit with loaders (Feb 25)
6287dd6  fixed error in .fit with loaders (Feb 25)
dc403e8  fixed error in .fit with loaders (Feb 25)
b2755f9  fixed error in .fit with loaders (Feb 25)
1be1cf6  fixed error in .fit with loaders (Feb 25)
83869c2  fixed error in .fit with loaders (Feb 25)
6bc9587  fixed error in .fit with loaders (Feb 25)
66f55d7  fixed error in .fit with loaders (Feb 25)
05ad57d  fixes #909 (Feb 25)
9504461  fixes #909 (Feb 25)
c12cb92  bug fix (Feb 25)
3173ad3  Fixes #902 (Feb 25)
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -8,6 +8,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 ### Added

+- Added automatic sampler setup. Depending on DDP or TPU, Lightning configures the sampler correctly (the user needs to do nothing) ([#926](https://github.com/PyTorchLightning/pytorch-lightning/pull/926))
+- Added `reload_dataloaders_every_epoch=False` flag for trainer. Some users require reloading data every epoch ([#926](https://github.com/PyTorchLightning/pytorch-lightning/pull/926))
+- Added `progress_bar_refresh_rate=50` flag for trainer. Throttles the refresh rate in notebooks ([#926](https://github.com/PyTorchLightning/pytorch-lightning/pull/926))
 - Updated governance docs
 - Added a check to ensure that the metric used for early stopping exists before training commences ([#542](https://github.com/PyTorchLightning/pytorch-lightning/pull/542))
 - Added `optimizer_idx` argument to `backward` hook ([#733](https://github.com/PyTorchLightning/pytorch-lightning/pull/733))
@@ -22,6 +25,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 ### Changed

+- Removed `@data_loader` decorator ([#926](https://github.com/PyTorchLightning/pytorch-lightning/pull/926))
 - Changed default TQDM to use `tqdm.auto` for prettier outputs in IPython notebooks ([#752](https://github.com/PyTorchLightning/pytorch-lightning/pull/752))
 - Changed `pytorch_lightning.logging` to `pytorch_lightning.loggers` ([#767](https://github.com/PyTorchLightning/pytorch-lightning/pull/767))
 - Moved the default `tqdm_dict` definition from Trainer to `LightningModule`, so it can be overridden by the user ([#749](https://github.com/PyTorchLightning/pytorch-lightning/pull/749))
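For reference, the two new Trainer flags from this PR can be exercised together (a minimal sketch; `MyLightningModule` and the chosen values are illustrative, not taken from this diff):

```python
import pytorch_lightning as pl

model = MyLightningModule()  # hypothetical module that defines the dataloader hooks

trainer = pl.Trainer(
    reload_dataloaders_every_epoch=True,  # rebuild dataloaders at every epoch (defaults to False)
    progress_bar_refresh_rate=50,         # redraw the progress bar every 50 batches; keeps notebooks responsive
)
trainer.fit(model)
```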
6 changes: 3 additions & 3 deletions docs/source/hooks.rst
@@ -8,9 +8,9 @@ Training set-up
 - init_optimizers
 - configure_apex
 - configure_ddp
-- get_train_dataloader
-- get_test_dataloaders
-- get_val_dataloaders
+- train_dataloader
+- test_dataloaders
+- val_dataloaders
 - summarize
 - restore_weights
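After the rename, these set-up steps resolve to dataloader hooks that users define as ordinary methods, with no decorator. A minimal sketch under the new names (the dataset, path, and batch size are illustrative):

```python
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST


class LitModel(pl.LightningModule):
    # plain method now; previously this hook sat behind @pl.data_loader
    def train_dataloader(self):
        dataset = MNIST('./data', train=True, transform=transforms.ToTensor())
        return DataLoader(dataset, batch_size=32)
```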
20 changes: 9 additions & 11 deletions pl_examples/basic_examples/lightning_module_template.py
@@ -192,37 +192,35 @@ def __dataloader(self, train):
         transform = transforms.Compose([transforms.ToTensor(),
                                         transforms.Normalize((0.5,), (1.0,))])
         dataset = MNIST(root=self.hparams.data_root, train=train,
-                        transform=transform, download=True)
+                        transform=transform, download=False)

         # when using multi-node (ddp) we need to add the datasampler
         train_sampler = None
         batch_size = self.hparams.batch_size

         if self.use_ddp:
             train_sampler = DistributedSampler(dataset)

         should_shuffle = train_sampler is None
         loader = DataLoader(
             dataset=dataset,
             batch_size=batch_size,
             shuffle=should_shuffle,
             sampler=train_sampler,
             num_workers=0
         )

         return loader

-    @pl.data_loader
+    def prepare_data(self):
+        transform = transforms.Compose([transforms.ToTensor(),
+                                        transforms.Normalize((0.5,), (1.0,))])
+        dataset = MNIST(root=self.hparams.data_root, train=True,
+                        transform=transform, download=True)
+        dataset = MNIST(root=self.hparams.data_root, train=False,
+                        transform=transform, download=True)
+
     def train_dataloader(self):
         log.info('Training data loader called.')
         return self.__dataloader(train=True)

-    @pl.data_loader
     def val_dataloader(self):
         log.info('Validation data loader called.')
         return self.__dataloader(train=False)

-    @pl.data_loader
     def test_dataloader(self):
         log.info('Test data loader called.')
         return self.__dataloader(train=False)

Review comment (Member), on the two `MNIST(...)` constructions in `prepare_data`: duplicated
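Two notes on this template. First, `__dataloader` now passes `download=False` because downloading moves to the new `prepare_data` hook, which (per the docstring added in `lightning.py` below) runs only once in distributed settings, so concurrent DDP/TPU processes no longer race to download the same files; the inline "duplicated" review comment flags the repeated `MNIST(...)` construction inside `prepare_data`. Second, the manual `DistributedSampler` wiring retained in `__dataloader` is exactly the boilerplate the changelog's "automatic sampler setup" entry automates. A sketch of what such auto-setup amounts to (`auto_add_sampler` is a hypothetical helper for illustration, not code from this diff):

```python
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler


def auto_add_sampler(loader: DataLoader, use_ddp: bool) -> DataLoader:
    """Rebuild a user-supplied loader with a DistributedSampler under DDP."""
    if not use_ddp:
        return loader
    sampler = DistributedSampler(loader.dataset)
    # shuffle must stay off when an explicit sampler is supplied
    return DataLoader(loader.dataset, batch_size=loader.batch_size,
                      sampler=sampler, num_workers=loader.num_workers)
```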
29 changes: 6 additions & 23 deletions pytorch_lightning/core/decorators.py
@@ -1,5 +1,6 @@
 import traceback
 from functools import wraps
+import warnings


 def data_loader(fn):

Review comment (Member), on the imports: add warning also here
@@ -8,27 +9,9 @@ def data_loader(fn):

     :param fn:
     :return:
     """
-    wraps(fn)
-    attr_name = '_lazy_' + fn.__name__
-
-    @wraps(fn)
-    def _get_data_loader(self):
-        try:
-            value = getattr(self, attr_name)
-        except AttributeError:
-            try:
-                value = fn(self)  # Lazy evaluation, done only once.
-                if (
-                    value is not None and
-                    not isinstance(value, list) and
-                    fn.__name__ in ['test_dataloader', 'val_dataloader']
-                ):
-                    value = [value]
-            except AttributeError as e:
-                # Guard against AttributeError suppression. (Issue #142)
-                traceback.print_exc()
-                error = f'{fn.__name__}: An AttributeError was encountered: ' + str(e)
-                raise RuntimeError(error) from e
-        setattr(self, attr_name, value)  # Memoize evaluation.
-        return value
-
-    return _get_data_loader
+    w = 'data_loader decorator deprecated in 0.6.1. Will remove 0.8.0'
+    warnings.warn(w)
+
+    def inner_fx(self):
+        return fn(self)
+    return inner_fx
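With the shim above, applying `@data_loader` still works but only warns; the lazy memoization and the list-wrapping of `val`/`test` loaders are gone, so every call re-invokes the wrapped function. A quick demonstration of the new behavior (a sketch; the class and return value are illustrative):

```python
import warnings

from pytorch_lightning.core.decorators import data_loader

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter('always')

    class Demo:  # stand-in for a LightningModule
        @data_loader  # the warning fires here, when the decorator runs at class-definition time
        def val_dataloader(self):
            return 'fake_loader'

assert 'deprecated in 0.6.1' in str(caught[0].message)

# no memoization: each call goes straight through to the original function
demo = Demo()
assert demo.val_dataloader() == 'fake_loader'
```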
41 changes: 37 additions & 4 deletions pytorch_lightning/core/lightning.py
@@ -16,6 +16,13 @@
 from pytorch_lightning.core.memory import ModelSummary
 from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel

+try:
+    import torch_xla.core.xla_model as xm
+    XLA_AVAILABLE = True
+
+except ImportError:
+    XLA_AVAILABLE = False


 class LightningModule(ABC, GradInformation, ModelIO, ModelHooks):

Review comment (Member), on the guarded import: rather

    try:
        import torch_xla.core.xla_model as xm
    except ImportError:
        XLA_AVAILABLE = False
    else:
        XLA_AVAILABLE = True
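The suggested `try/except/else` form is the idiomatic guarded import: only the `import` statement sits inside the `try`, the flag is assigned exactly once per branch, and nothing else in the success path can be accidentally swallowed by the `except ImportError` clause.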

@@ -798,7 +805,9 @@ def optimizer_step(self, current_epoch, batch_idx, optimizer, optimizer_idx, second_order_closure=None):
             optimizer.zero_grad()

         """
-        if isinstance(optimizer, torch.optim.LBFGS):
+        if self.trainer.use_tpu and XLA_AVAILABLE:
+            xm.optimizer_step(optimizer)
+        elif isinstance(optimizer, torch.optim.LBFGS):
             optimizer.step(second_order_closure)
         else:
             optimizer.step()
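The TPU branch matters because `xm.optimizer_step` is documented by `torch_xla` to reduce gradients across XLA replicas before applying the update, so a bare `optimizer.step()` would step each core with unsynchronized gradients. The LBFGS branch is unchanged: LBFGS requires the closure so it can re-evaluate the loss during its internal line search.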
@@ -868,7 +877,33 @@ def tbptt_split_batch(self, batch, split_size):

         return splits

-    @data_loader
+    def prepare_data(self):
+        """Use this to download and prepare data.
+        In distributed (GPU, TPU), this will only be called once
+
+        :return: PyTorch DataLoader
+
+        This is called before requesting the dataloaders
+
+        .. code-block:: python
+
+            model.prepare_data()
+            model.train_dataloader()
+            model.val_dataloader()
+            model.test_dataloader()
+
+        Example
+        -------
+
+        .. code-block:: python
+
+            def prepare_data(self):
+                download_imagenet()
+                clean_imagenet()
+                cache_imagenet()
+        """
+        return None
+
     def train_dataloader(self):
         """Implement a PyTorch DataLoader
@@ -908,7 +943,6 @@ def tng_dataloader(self):  # todo: remove in v0.8.0
             " and this method will be removed in v0.8.0", DeprecationWarning)
         return output

-    @data_loader
     def test_dataloader(self):
         r"""
@@ -942,7 +976,6 @@ def test_dataloader(self):
         """
         return None

-    @data_loader
     def val_dataloader(self):
         r"""