Passing dataloader to trainer.fit() doesn't work with tpu (and maybe ddp) #968
Comments
I don't mind tackling the issue myself, but I'd like some input first. Thanks everyone!

lightning automates the sampler. pass just the dataloader
that's super cool. I'm giving that a try, both in the GCP XLA docker image and in colab. It seems to run fine in a colab modified from the TPU-MNIST example, but is still failing in the GCP XLA docker image. I suppose that's due to the different start_method:

# COLAB_GPU is an env var available by default in Colab environments.
start_method = 'fork' if os.getenv('COLAB_GPU') else 'spawn'
xmp.spawn(self.tpu_train, args=(model,), nprocs=self.num_tpu_cores, start_method=start_method)

Reading the flow of trainer.fit():

### excerpt from trainer.fit()
# set up the passed in dataloaders (if needed)
### !! this will set the unpickleable local functions
self.__set_fit_dataloaders(model, train_dataloader, val_dataloaders, test_dataloaders)
...
elif self.use_ddp:
    if self.is_slurm_managing_tasks:
        task = int(os.environ['SLURM_LOCALID'])
        self.ddp_train(task, model)
    else:
        ### !! I expect this will fail. need testing.
        mp.spawn(self.ddp_train, nprocs=self.num_gpus, args=(model,))
...
elif self.use_tpu:
    log.info(f'training on {self.num_tpu_cores} TPU cores')
    # COLAB_GPU is an env var available by default in Colab environments.
    start_method = 'fork' if os.getenv('COLAB_GPU') else 'spawn'
    ### !! using start_method=spawn here requires pickling
    xmp.spawn(self.tpu_train, args=(model,), nprocs=self.num_tpu_cores, start_method=start_method)

EDIT:
so you think the spawn method is causing issues?

Yes.
From the pickle docs, the following types can be pickled: None, booleans, numbers, strings, containers of picklable objects, and functions and classes defined at the top level of a module.
The "functions defined at the top level" requirement is what we're failing currently, by defining the dataloader-patching function locally (inside __set_fit_dataloaders) instead of at module level.
oh yeah. that’s just one way to do it. we could do the same another way. Basically just wanted to plug it into the framework instead of coming up with new functionality. want to submit a PR?

I can take a stab at it with my proposed solution if that's cool

looks great! give it a shot
🐛 Bug
Receive a pickling error when passing the dataloader directly to trainer.fit(model, train_loader).
To Reproduce
Steps to reproduce the behavior:
Try to call trainer.fit(model, train_loader) in TPU mode. (I suspect that anything that calls mp.spawn will cause this problem, so ddp probably will face this issue too.)
Code sample
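(The original code sample was not captured above. The following is a minimal sketch of the kind of script that triggers the error; LitModel and the random toy dataset are illustrative, and Trainer(num_tpu_cores=...) follows the Lightning API this issue targets.)

import torch
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl

class LitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.cross_entropy(self(x), y)
        return {'loss': loss}

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters())

# a toy dataset stands in for the real data
train_loader = DataLoader(
    TensorDataset(torch.randn(64, 32), torch.randint(0, 2, (64,))),
    batch_size=8,
)

# passing the dataloader to fit() in TPU mode is what triggers the pickling error
trainer = pl.Trainer(num_tpu_cores=8)
trainer.fit(LitModel(), train_loader)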
Expected behavior
Ideally, specifying the dataloaders as part of the LightningModule should work just the same as passing the dataloaders into trainer.fit().
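Roughly, both of the following should behave identically (illustrative sketch reusing the LitModel from the Code sample section, with a placeholder dataset):

import torch
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl

dataset = TensorDataset(torch.randn(64, 32), torch.randint(0, 2, (64,)))

# Option A: dataloader defined on the LightningModule (works today)
class LitModelA(LitModel):
    def train_dataloader(self):
        return DataLoader(dataset, batch_size=8)

pl.Trainer(num_tpu_cores=8).fit(LitModelA())

# Option B: the same dataloader passed to fit() (currently fails in TPU mode)
pl.Trainer(num_tpu_cores=8).fit(LitModel(), DataLoader(dataset, batch_size=8))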
Environment
Please copy and paste the output from our environment collection script (or fill out the checklist below manually).
(I don't have access to the machine right now, so please forgive me on the specific version info temporarily.)
Proposed solution
The issue is in __set_fit_dataloaders, which assigns a local function to the model.
Instead of using a closure or a local function, you could use a callable defined at the top-level. This will be pickleable.
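A minimal sketch of that idea (the class name and the exact attribute being patched are illustrative, not necessarily what the eventual PR does):

class _PatchDataLoader:
    """Picklable replacement for the local patch function.

    Because this class is defined at the top level of a module, its
    instances can be pickled and therefore survive mp.spawn / xmp.spawn.
    """

    def __init__(self, dataloader):
        self.dataloader = dataloader

    def __call__(self):
        return self.dataloader

# inside __set_fit_dataloaders, instead of assigning a local function:
# model.train_dataloader = _PatchDataLoader(train_dataloader)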