Passing dataloader to trainer.fit() doesn't work with tpu (and maybe ddp) #968

Closed
shoarora opened this issue Feb 27, 2020 · 10 comments · Fixed by #971
Labels: bug, help wanted

@shoarora

🐛 Bug

I get the following error when passing the dataloader directly via trainer.fit(model, train_loader):

AttributeError: Can't pickle local object 'Trainer.__set_fit_dataloaders.<locals>.patch_train_dataloader'

To Reproduce

Steps to reproduce the behavior:

Try to call trainer.fit(model, train_loader) in TPU mode.

(I suspect that anything that calls mp.spawn will cause this problem, so ddp probably will face this issue too.)

Code sample

import os

import pytorch_lightning as pl
import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST

import torch_xla.core.xla_model as xm


class CoolSystem(pl.LightningModule):
    def __init__(self, use_tpu=False):
        super(CoolSystem, self).__init__()
        # not the best model...
        self.use_tpu = use_tpu
        self.l1 = torch.nn.Linear(28 * 28, 10)

    def forward(self, x):
        return torch.relu(self.l1(x.view(x.size(0), -1)))

    def training_step(self, batch, batch_idx):
        # REQUIRED
        x, y = batch
        y_hat = self.forward(x)
        loss = F.cross_entropy(y_hat, y)
        tensorboard_logs = {'train_loss': loss}
        return {'loss': loss, 'log': tensorboard_logs}

    def validation_step(self, batch, batch_idx):
        # OPTIONAL
        x, y = batch
        y_hat = self.forward(x)
        return {'val_loss': F.cross_entropy(y_hat, y)}

    def validation_end(self, outputs):
        # OPTIONAL
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        tensorboard_logs = {'val_loss': avg_loss}
        return {'avg_val_loss': avg_loss, 'log': tensorboard_logs}

    def test_step(self, batch, batch_idx):
        # OPTIONAL
        x, y = batch
        y_hat = self.forward(x)
        return {'test_loss': F.cross_entropy(y_hat, y)}

    def test_end(self, outputs):
        # OPTIONAL
        avg_loss = torch.stack([x['test_loss'] for x in outputs]).mean()
        tensorboard_logs = {'test_loss': avg_loss}
        return {'avg_test_loss': avg_loss, 'log': tensorboard_logs}

    def configure_optimizers(self):
        # REQUIRED
        # can return multiple optimizers and learning_rate schedulers
        # (LBFGS it is automatically supported, no need for closure function)
        return torch.optim.Adam(self.parameters(), lr=0.0004)


if __name__ == '__main__':
    from pytorch_lightning import Trainer

    model = CoolSystem(use_tpu=True)

    dataset = MNIST(os.getcwd(),
                    train=True,
                    download=True,
                    transform=transforms.ToTensor())

    sampler = torch.utils.data.distributed.DistributedSampler(
        dataset,
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(),
        shuffle=True)

    loader = DataLoader(dataset, sampler=sampler, batch_size=32)

    # most basic trainer, uses good defaults
    trainer = Trainer(num_tpu_cores=8)
    trainer.fit(model, loader)

Expected behavior

Ideally, specifying the dataloaders as part of the LightningModule should work just the same as passing the dataloaders into trainer.fit()

Environment

  • Docker image: gcr.io/tpu-pytorch/xla:nightly
  • build steps:
    • pip install git+git://github.com/williamFalcon/pytorch-lightning.git@master --upgrade

(I don't have access to the machine right now, so please forgive me on the specific version info temporarily)

Proposed solution

The issue is here: __set_fit_dataloaders assigns a locally defined function to the model, and a local function can't be pickled.

def __set_fit_dataloaders(self, model, train_dataloader, val_dataloaders, test_dataloaders):
        # when dataloader is passed via fit, patch the train_dataloader
        # functions to overwrite with these implementations
        if train_dataloader is not None:
            if not self.is_overriden('training_step', model):
                m = 'You called .fit() with a train_dataloader but did not define training_step()'
                raise MisconfigurationException(m)

            def patch_train_dataloader():
                return train_dataloader

            model.train_dataloader = patch_train_dataloader

Instead of a closure or local function, you could use a callable class defined at the top level of the module. Its instances are picklable, provided the wrapped loader is picklable too.

class DataLoaderPatcher:
    # top-level callable that stands in for model.train_dataloader
    def __init__(self, loader):
        self.loader = loader

    def __call__(self):
        return self.loader

def __set_fit_dataloaders(self, model, train_dataloader, val_dataloaders, test_dataloaders):
        # when dataloader is passed via fit, patch the train_dataloader
        # functions to overwrite with these implementations
        if train_dataloader is not None:
            if not self.is_overriden('training_step', model):
                m = 'You called .fit() with a train_dataloader but did not define training_step()'
                raise MisconfigurationException(m)

            model.train_dataloader = DataLoaderPatcher(train_dataloader)
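
To illustrate the difference outside of Lightning, here is a minimal standalone sketch. It is not Lightning code: FakeModel, patch_with_closure, patch_with_callable and can_pickle are hypothetical names made up for this example, and a plain list stands in for the DataLoader.

```python
import pickle


class DataLoaderPatcher:
    """Top-level callable that returns the loader it wraps."""

    def __init__(self, loader):
        self.loader = loader

    def __call__(self):
        return self.loader


class FakeModel:
    """Hypothetical stand-in for a LightningModule (assumption, not the real class)."""


def patch_with_closure(model, loader):
    # Mirrors the current approach: a function defined inside another function.
    def patch_train_dataloader():
        return loader

    model.train_dataloader = patch_train_dataloader
    return model


def patch_with_callable(model, loader):
    # Mirrors the proposed approach: an instance of a top-level class.
    model.train_dataloader = DataLoaderPatcher(loader)
    return model


def can_pickle(obj):
    """Return True if pickle can serialize obj, False if it refuses."""
    try:
        pickle.dumps(obj)
        return True
    except (AttributeError, pickle.PicklingError):
        return False


if __name__ == "__main__":
    fake_loader = [1, 2, 3]  # a plain list standing in for a DataLoader
    print(can_pickle(patch_with_closure(FakeModel(), fake_loader)))
    print(can_pickle(patch_with_callable(FakeModel(), fake_loader)))
```

The model patched with the closure should be rejected by pickle, while the one patched with the callable instance should round-trip and still return the wrapped loader when train_dataloader() is called.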

@shoarora

I don't mind tackling the issue myself, but I'd like some input first. Thanks everyone!

@williamFalcon

lightning automates the sampler. pass just the dataloader

@shoarora

shoarora commented Feb 27, 2020

that's super cool. I'm giving that a try, both in the GCP XLA docker image and in colab.

It seems to run fine in a colab modified from the TPU-MNIST example, but is still failing in the GCP XLA docker image.

I suppose that's due to the different start_methods?

#  COLAB_GPU is an env var available by default in Colab environments.
start_method = 'fork' if os.getenv('COLAB_GPU') else 'spawn'
xmp.spawn(self.tpu_train, args=(model,), nprocs=self.num_tpu_cores, start_method=start_method)

Reading the flow of trainer.fit(), I don't see anything that would change the local functions set by __set_fit_dataloaders before hitting mp.spawn or xmp.spawn(start_method='spawn').

### excerpt from trainer.fit()

        # set up the passed in dataloaders (if needed)
        ### !! this will set the unpickleable local functions
        self.__set_fit_dataloaders(model, train_dataloader, val_dataloaders, test_dataloaders)

        ...

        elif self.use_ddp:
            if self.is_slurm_managing_tasks:
                task = int(os.environ['SLURM_LOCALID'])
                self.ddp_train(task, model)
            else:
                ### !! I expect this will fail.  need testing.
                mp.spawn(self.ddp_train, nprocs=self.num_gpus, args=(model,))

        ...

        elif self.use_tpu:
            log.info(f'training on {self.num_tpu_cores} TPU cores')

            #  COLAB_GPU is an env var available by default in Colab environments.
            start_method = 'fork' if os.getenv('COLAB_GPU') else 'spawn'

            ### !! using start_method=spawn here requires pickling
            xmp.spawn(self.tpu_train, args=(model,), nprocs=self.num_tpu_cores, start_method=start_method)

EDIT:
here is the modified colab I was running. I exported the same notebook, ran it in a non-colab setting, and got the failure.

@williamFalcon

so you think the spawn method is causing issues?

@shoarora

so you think the spawn method is causing issues?

Yes. spawn pickles the function and arguments to hand them to a new process, and patch_train_dataloader is not pickleable
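
A rough way to check this without starting any processes, as far as I understand the standard library: with the spawn start method, multiprocessing serializes what it sends to the child with its ForkingPickler, so calling ForkingPickler.dumps directly should hit the same failure. The function names below are made up for this sketch.

```python
import pickle
from multiprocessing.reduction import ForkingPickler


def top_level_function():
    return "hello"


def make_local_function():
    # Returns a function defined inside another function, like
    # patch_train_dataloader inside __set_fit_dataloaders.
    def local_function():
        return "hello"

    return local_function


def survives_spawn_pickling(target):
    """Return True if target can be serialized with multiprocessing's pickler."""
    try:
        ForkingPickler.dumps(target)
        return True
    except (AttributeError, pickle.PicklingError):
        return False


if __name__ == "__main__":
    print(survives_spawn_pickling(top_level_function))
    print(survives_spawn_pickling(make_local_function()))
```

With fork the child inherits the parent's memory, so nothing needs to be pickled, which would explain why the Colab run does not hit the error.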

@shoarora

A quick run of the same example in ddp mode (colab, so non-slurm) results in the same error

@shoarora

From the pickle docs:

The following types can be pickled:

  • None, True, and False
  • integers, floating point numbers, complex numbers
  • strings, bytes, bytearrays
  • tuples, lists, sets, and dictionaries containing only picklable objects
  • functions defined at the top level of a module (using def, not lambda)
  • built-in functions defined at the top level of a module
  • classes that are defined at the top level of a module

The "functions defined at the top level" requirement is what we're failing currently, by defining the patch_train_dataloader() function inside another function.
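
A small sketch of why the nesting matters: pickle stores a function by its qualified name and looks it up again on load, and a nested function's qualified name contains a <locals> segment that cannot be looked up from the module. That segment is the same one that appears in the error message above. The names outer, inner, top_level and is_defined_locally are hypothetical, made up for this example.

```python
def top_level():
    return 1


def outer():
    def inner():
        return 1

    return inner


def is_defined_locally(fn):
    """Heuristic check: nested functions carry '<locals>' in their qualified name."""
    return "<locals>" in fn.__qualname__


if __name__ == "__main__":
    print(top_level.__qualname__)  # top_level
    print(outer().__qualname__)  # outer.<locals>.inner
    print(is_defined_locally(outer()))  # True
```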

@williamFalcon

oh yeah. that’s just one way to do it. we could do the same another way. Basically just wanted to plug it into the framework instead of coming up with new functionality.

want to submit a PR?
or @ethanwharris @Borda

@shoarora

I can take a stab at it with my proposed solution if that's cool

@williamFalcon

looks great! give it a shot
