
Deepspeed accelerator calls datamodule.train_dataloader() prior to setup() #8872

Closed · leezu opened this issue Aug 12, 2021 · 4 comments · Fixed by #9221
Labels: 3rd party (Related to a 3rd-party), bug (Something isn't working), data handling (Generic data-related topic), help wanted (Open to be worked on)

leezu (Contributor) commented Aug 12, 2021

🐛 Bug

The DeepSpeed accelerator calls datamodule.train_dataloader() prior to setup(). This does not happen with other accelerators.

To Reproduce

import os

import torch
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import LightningModule, Trainer, LightningDataModule


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class PlDataModule(LightningDataModule):
    def __init__(self):
        super().__init__()
        self._setup = False

    def setup(self, stage):
        self._setup = True

    def train_dataloader(self):
        # Fails loudly if anything requests the dataloader before setup() ran.
        assert self._setup
        return DataLoader(RandomDataset(32, 64), batch_size=2)

    def val_dataloader(self):
        assert self._setup
        return DataLoader(RandomDataset(32, 64), batch_size=2)

    def test_dataloader(self):
        assert self._setup
        return DataLoader(RandomDataset(32, 64), batch_size=2)


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("valid_loss", loss)

    def test_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("test_loss", loss)

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)


def run():

    model = BoringModel()
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        limit_train_batches=1,
        limit_val_batches=1,
        num_sanity_val_steps=0,
        max_epochs=1,
        weights_summary=None,
        accelerator="deepspeed",
    )
    trainer.fit(model, datamodule=PlDataModule())


if __name__ == "__main__":
    run()

Expected behavior

train_dataloader is never called before setup.

Additional context

Backtrace

  /home/ubuntu/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py(553)fit()
-> self._run(model)
  /home/ubuntu/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py(865)_run()
-> self.accelerator.setup_environment()
  /home/ubuntu/.local/lib/python3.8/site-packages/pytorch_lightning/accelerators/gpu.py(30)setup_environment()
-> super().setup_environment()
  /home/ubuntu/.local/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py(76)setup_environment()
-> self.training_type_plugin.setup_environment()
  /home/ubuntu/.local/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/ddp.py(166)setup_environment()
-> self.setup_distributed()
  /home/ubuntu/.local/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/deepspeed.py(341)setup_distributed()
-> self._format_config()
  /home/ubuntu/.local/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/deepspeed.py(545)_format_config()
-> self._format_batch_size_and_grad_accum_config()
  /home/ubuntu/.local/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/deepspeed.py(555)_format_batch_size_and_grad_accum_config()
-> batch_size = self._auto_select_batch_size()
  /home/ubuntu/.local/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/deepspeed.py(566)_auto_select_batch_size()
-> train_dataloader = self.lightning_module.train_dataloader()
tchaton (Contributor) commented Aug 27, 2021

Dear @leezu,

The DeepSpeed plugin uses a workaround that calls your dataloader to automatically resolve the batch_size, but it won't actually use that dataloader for training. There is no simple way to avoid this unless you provide the batch_size within the DeepSpeed config.

Also note that the DataLoader actually used for training will be created within a distributed setting.

My advice is to use the is_distributed_available() function if you have custom samplers, as sketched below.
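
A minimal sketch of that guard, reusing PlDataModule and RandomDataset from the repro above. It checks availability via torch.distributed directly (an assumption; the exact Lightning helper may differ by version), so the custom sampler is only attached once the process group exists:

import torch.distributed as dist
from torch.utils.data import DataLoader, DistributedSampler


class GuardedDataModule(PlDataModule):
    def train_dataloader(self):
        dataset = RandomDataset(32, 64)
        sampler = None
        # When DeepSpeed probes this hook before distributed init,
        # fall back to the default sampler instead of failing.
        if dist.is_available() and dist.is_initialized():
            sampler = DistributedSampler(dataset)
        return DataLoader(dataset, batch_size=2, sampler=sampler)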

I will be closing this issue for now.

Best,
T.C

tchaton closed this as completed Aug 27, 2021
leezu (Contributor, Author) commented Aug 27, 2021

@tchaton thank you for looking into a fix.

> There is no simple way to avoid this unless you provide the batch_size within the DeepSpeed config.

Even if setup can't be called automatically, how about adding sanity assertions that ensure train_dataloader is never called before setup? That way, users would at least be informed about the issue. A sketch of what I mean follows.
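
A hedged sketch of such a check, reusing LightningDataModule and RandomDataset from the repro above; the attribute name and error message are illustrative, not Lightning's actual implementation:

from torch.utils.data import DataLoader


class CheckedDataModule(LightningDataModule):
    def __init__(self):
        super().__init__()
        self._setup_called = False

    def setup(self, stage=None):
        self._setup_called = True

    def train_dataloader(self):
        # Raise a descriptive error instead of a bare AssertionError so the
        # user learns why the hook fired too early.
        if not self._setup_called:
            raise RuntimeError(
                "train_dataloader() was called before setup(); a plugin is "
                "probably probing the dataloader too early."
            )
        return DataLoader(RandomDataset(32, 64), batch_size=2)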

tchaton (Contributor) commented Aug 30, 2021

@SeanNaren Any ideas?

SeanNaren (Contributor) commented Aug 31, 2021

Thanks for the issue @leezu!

To skip the auto-inference of the batch size used for logging, you can pass the batch size directly to the plugin, like so:

Trainer(plugins=DeepSpeedPlugin(logging_batch_size_per_gpu=32))

If I modify the example accordingly, it works :)

import os

import torch
from pytorch_lightning import LightningModule, Trainer, LightningDataModule
from pytorch_lightning.plugins import DeepSpeedPlugin
from torch.utils.data import DataLoader, Dataset


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class PlDataModule(LightningDataModule):
    def __init__(self):
        super().__init__()
        self._setup = False

    def setup(self, stage):
        self._setup = True

    def train_dataloader(self):
        assert self._setup
        return DataLoader(RandomDataset(32, 64), batch_size=2)

    def val_dataloader(self):
        assert self._setup
        return DataLoader(RandomDataset(32, 64), batch_size=2)

    def test_dataloader(self):
        assert self._setup
        return DataLoader(RandomDataset(32, 64), batch_size=2)


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("valid_loss", loss)

    def test_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("test_loss", loss)

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)


def run():

    model = BoringModel()
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        limit_train_batches=1,
        limit_val_batches=1,
        num_sanity_val_steps=0,
        max_epochs=1,
        weights_summary=None,
        plugins=DeepSpeedPlugin(logging_batch_size_per_gpu=32),
        gpus=1
    )
    trainer.fit(model, datamodule=PlDataModule())


if __name__ == "__main__":
    run()

This should skip the auto-inference of the batch size. We should add a warning here so that users are aware! Will add that in a PR.
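
A rough sketch of that warning, placed inside the DeepSpeed plugin where the backtrace above ends; this is an illustration of the idea, not the patch that was merged in #9221 (rank_zero_warn does exist in pytorch_lightning.utilities, while the getattr fallback is an assumption):

from pytorch_lightning.utilities import rank_zero_warn


def _auto_select_batch_size(self):
    # Warn that the dataloader is being probed early, and point users at the
    # explicit override so they can opt out.
    rank_zero_warn(
        "Inferring the batch size for internal DeepSpeed logging from the "
        "`train_dataloader()`. To avoid this, pass "
        "`DeepSpeedPlugin(logging_batch_size_per_gpu=...)` instead."
    )
    train_dataloader = self.lightning_module.train_dataloader()
    return getattr(train_dataloader, "batch_size", 1) or 1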
