Clean up dataloader logic #926
Conversation
Hello @williamFalcon! Thanks for updating this PR. There are currently no PEP 8 issues detected in this Pull Request. Cheers! 🍻 Comment last updated at 2020-02-25 03:14:52 UTC
dl_args = {
    'dataset': dataloader.dataset,
    'batch_size': dataloader.batch_size,
    'shuffle': False,
What if a user wants to shuffle batches (when running on a single machine)? I see below that in certain cases you're re-setting this value to False; did you intend to set it to True here?
I would rather move shuffle to the function arguments, since the other values are taken from the dataloader and only this one is hard-coded.
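A minimal sketch of what I mean, with illustrative names (rebuild_dataloader is not in the PR):

from torch.utils.data import DataLoader

def rebuild_dataloader(dataloader, shuffle=False):
    # Take shuffle as an explicit argument; everything else is copied
    # from the user's dataloader, so only shuffle needs to be decided here.
    dl_args = {
        'dataset': dataloader.dataset,
        'batch_size': dataloader.batch_size,
        'shuffle': shuffle,
        'num_workers': dataloader.num_workers,
    }
    return DataLoader(**dl_args)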
The PR is really huge, so these are just my quick comments...
def prepare_data(self):
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize((0.5,), (1.0,))])
    dataset = MNIST(root=self.hparams.data_root, train=True,
duplicated
@@ -1,5 +1,6 @@
import traceback
add a warning here as well
    XLA_AVAILABLE = True

except ImportError:
    XLA_AVAILABLE = False
Rather:
try:
import torch_xla.core.xla_model as xm
except ImportError:
XLA_AVAILABLE = False
else:
XLA_AVAILABLE = True
dl_args = {
    'dataset': dataloader.dataset,
    'batch_size': dataloader.batch_size,
    'shuffle': False,
I would rather move shuffle to the function arguments, since the other values are taken from the dataloader and only this one is hard-coded.
if train:
    if self.use_ddp or self.use_ddp2:
        sampler = DistributedSampler(dataloader.dataset)
        dl_args['shuffle'] = False
why set this here if it is already fixed as False above?
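For context, a minimal sketch of the usual pattern: the DataLoader's shuffle has to stay False when a sampler is passed, because sampler and shuffle=True are mutually exclusive in PyTorch and the DistributedSampler does the shuffling itself.

import torch.distributed as dist
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

def build_ddp_loader(dataset, batch_size):
    # The sampler shards and shuffles per rank; the DataLoader must not shuffle.
    sampler = DistributedSampler(dataset,
                                 num_replicas=dist.get_world_size(),
                                 rank=dist.get_rank())
    return DataLoader(dataset, batch_size=batch_size, shuffle=False, sampler=sampler)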
        warnings.warn(msg)
        break

def init_test_dataloader(self, model):
I guess this can simply be unified, as the content is almost the same.
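Roughly what I have in mind, as a sketch only (the helper name and the train=False flag passed to auto_add_sampler are illustrative):

def _init_eval_dataloaders(self, model, mode):
    # mode is 'val' or 'test'; the rest of the logic is identical,
    # so one helper can back both init_val_dataloader and init_test_dataloader.
    dataloaders = getattr(model, f'{mode}_dataloader')()
    if not isinstance(dataloaders, (list, tuple)):
        dataloaders = [dataloaders]
    return [self.auto_add_sampler(dl, train=False) for dl in dataloaders]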
def __set_fit_dataloaders(self, model, train_dataloader, val_dataloaders, test_dataloaders):
    # when dataloader is passed via fit, patch the train_dataloader
    # functions to overwrite with these implementations
    if train_dataloader is not None:
this may be unified...
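For example, the three near-identical branches could be collapsed into one loop, along these lines (a sketch, not the PR code):

def __set_fit_dataloaders(self, model, train_dataloader, val_dataloaders, test_dataloaders):
    # Patch the model's *_dataloader hooks with whatever was passed to .fit.
    overrides = {
        'train_dataloader': train_dataloader,
        'val_dataloader': val_dataloaders,
        'test_dataloader': test_dataloaders,
    }
    for name, loader in overrides.items():
        if loader is not None:
            # default argument captures the current loader, avoiding late binding
            setattr(model, name, lambda loader=loader: loader)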
def prepare_data(self):
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize((0.5,), (1.0,))])
    dataset = TestingMNIST(root=self.hparams.data_root, train=True,
duplicated
# acc
labels_hat = torch.argmax(y_hat, dim=1)
test_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)
isn't it already a tensor here?
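If y and y_hat are tensors here, the accuracy can stay a tensor without the .item() round-trip, e.g.:

# equivalent, but keeps the result as a tensor
labels_hat = torch.argmax(y_hat, dim=1)
test_acc = (y == labels_hat).float().mean()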
    return output


class LightningTestFitMultipleTestDataloadersMixin:
it is not easy to see what the difference is from LightningTestFitSingleTestDataloadersMixin
I still feel this puts too much restriction on the data loader.
I don't disagree. Maybe a good approach is to check that it's a PyTorch dataloader? What other dataloaders are there?
What I meant is that Lightning should not touch the data loader that the user provides unless necessary. I can open up an issue if you think that's a better place for this discussion.
We could make this a method you can override in the LightningModule. In what use case do you need to keep the original loader? We could also use a flag in the trainer:
I have multiple dataloaders that each load images in order. The auto_add_sampler function says it shouldn't do anything when the user provides a sampler; we should at least fix this part.
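Concretely, a check along these lines would cover that case (a sketch; the helper name is illustrative):

from torch.utils.data import SequentialSampler, RandomSampler

def has_default_sampler(dataloader):
    # Only the default samplers should be replaced; a user-provided sampler
    # (e.g. one that keeps images in order) must be left untouched.
    return isinstance(dataloader.sampler, (SequentialSampler, RandomSampler))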
We're currently addressing this in the fix for #953 - will PR soon. The solution is to re-write the
@ethanwharris I will repost my concerns there then. Thanks for the pointer.
@versatran01 that would be cool, looking forward to your points 🤖
* added get dataloaders directly using a getter
* deleted decorator
* added prepare_data hook
* refactored dataloader init
* added dataloader reset flag and main loop
* made changes
* fixed bad loaders
* fixed error in .fit with loaders
* fixes Lightning-AI#909
* bug fix
* Fixes Lightning-AI#902
Fixes #928
Fixes #927
Fixes #922
Fixes #909
Fixes #859
Fixes #902
Removes data_decorator
Adds prepare_data
Lightning needs a step to download data on proc 0 only
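For example (a sketch; hparams.data_root and the MNIST setup are taken from the test code above):

import pytorch_lightning as pl
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST

class MyModule(pl.LightningModule):
    def prepare_data(self):
        # runs on proc 0 only: safe place to download / write to disk
        MNIST(root=self.hparams.data_root, train=True, download=True)

    def train_dataloader(self):
        # runs on every process: build the dataset and loader, no downloads here
        transform = transforms.Compose([transforms.ToTensor(),
                                        transforms.Normalize((0.5,), (1.0,))])
        dataset = MNIST(root=self.hparams.data_root, train=True, transform=transform)
        return DataLoader(dataset, batch_size=32)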
Added new flags
Fixes .fit with data
The .fit(dataloaders) path was buggy. Simplified it to hook into the rest of the framework instead of running its own ad hoc process.
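Usage then looks roughly like this (a sketch; the argument names follow __set_fit_dataloaders above, and MyModule, hparams, train_loader, and val_loader are placeholders):

from pytorch_lightning import Trainer

model = MyModule(hparams)
trainer = Trainer()
trainer.fit(model,
            train_dataloader=train_loader,
            val_dataloaders=val_loader)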
Automatic sampler
Now the user doesn't have to mess around with samplers on DDP or TPUs; Lightning sets them up automatically.
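Conceptually, the automatic sampler does something like this under DDP (a simplified sketch of the idea, not the exact implementation):

from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

def auto_add_sampler(dataloader, use_ddp):
    # Re-create the user's loader with a DistributedSampler so each rank
    # sees its own shard; the user keeps returning plain DataLoaders.
    if not use_ddp:
        return dataloader
    sampler = DistributedSampler(dataloader.dataset)
    return DataLoader(dataloader.dataset,
                      batch_size=dataloader.batch_size,
                      shuffle=False,
                      sampler=sampler)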