Deepspeed accelerator calls datamodule.train_dataloader() prior to setup() #8872
Comments
Dear @leezu, […] the DataLoader used for training will be defined within a distributed setting. My advice is to use […]. I will be closing this issue for now. Best,
@tchaton thank you for looking into a fix.
Even if […]
@SeanNaren Any idea?
Thanks for the issue @leezu! To skip the auto-infer of the batch size for logging, you can pass the batch size directly to the plugin like so: `Trainer(plugins=DeepSpeedPlugin(logging_batch_size_per_gpu=32))`. If I modify the case, it works :)

```python
import os

import torch
from pytorch_lightning import LightningModule, Trainer, LightningDataModule
from pytorch_lightning.plugins import DeepSpeedPlugin
from torch.utils.data import DataLoader, Dataset


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class PlDataModule(LightningDataModule):
    def __init__(self):
        super().__init__()
        self._setup = False

    def setup(self, stage):
        self._setup = True

    def train_dataloader(self):
        assert self._setup
        return DataLoader(RandomDataset(32, 64), batch_size=2)

    def val_dataloader(self):
        assert self._setup
        return DataLoader(RandomDataset(32, 64), batch_size=2)

    def test_dataloader(self):
        assert self._setup
        return DataLoader(RandomDataset(32, 64), batch_size=2)


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("valid_loss", loss)

    def test_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("test_loss", loss)

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)


def run():
    model = BoringModel()
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        limit_train_batches=1,
        limit_val_batches=1,
        num_sanity_val_steps=0,
        max_epochs=1,
        weights_summary=None,
        plugins=DeepSpeedPlugin(logging_batch_size_per_gpu=32),
        gpus=1,
    )
    trainer.fit(model, datamodule=PlDataModule())


if __name__ == "__main__":
    run()
```

This should skip the auto-infer of the batch size. We should add a warning here so that users are aware! Will add that in a PR.
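For context, such a warning would live in the plugin's batch-size inference path. The sketch below only illustrates the idea and is not the actual PR or the real pytorch_lightning source; the class and method names are assumptions, and only `rank_zero_warn` is a real Lightning utility.

```python
# Hypothetical sketch: warn before touching the train dataloader to infer a
# logging batch size. Names are illustrative, not the actual plugin code.
from pytorch_lightning.utilities import rank_zero_warn


class DeepSpeedPluginSketch:
    def __init__(self, logging_batch_size_per_gpu=None):
        self.logging_batch_size_per_gpu = logging_batch_size_per_gpu

    def _infer_logging_batch_size(self, datamodule):
        if self.logging_batch_size_per_gpu is not None:
            # User supplied the batch size explicitly: no dataloader access needed.
            return self.logging_batch_size_per_gpu
        rank_zero_warn(
            "Inferring the batch size for internal logging by calling "
            "train_dataloader(); this can happen before setup(). Pass "
            "`logging_batch_size_per_gpu` to DeepSpeedPlugin to avoid it."
        )
        loader = datamodule.train_dataloader()
        return loader.batch_size
```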
🐛 Bug
The DeepSpeed accelerator calls `datamodule.train_dataloader()` prior to `setup()`. This does not happen with other accelerators.
To Reproduce
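The original reproduction script is not preserved here; the following is a minimal sketch reconstructed from the maintainer's modified example above, with no `logging_batch_size_per_gpu`, so the plugin inspects the train dataloader and the `assert self._setup` fires before `setup()` has run. Requires a GPU and deepspeed installed.

```python
import torch
from pytorch_lightning import LightningModule, LightningDataModule, Trainer
from pytorch_lightning.plugins import DeepSpeedPlugin
from torch.utils.data import DataLoader, Dataset


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.data = torch.randn(length, size)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index]


class PlDataModule(LightningDataModule):
    def __init__(self):
        super().__init__()
        self._setup = False

    def setup(self, stage):
        self._setup = True

    def train_dataloader(self):
        # Fails if the plugin asks for the dataloader before setup() has run.
        assert self._setup
        return DataLoader(RandomDataset(32, 64), batch_size=2)


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        return self.layer(batch).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)


if __name__ == "__main__":
    trainer = Trainer(max_epochs=1, gpus=1, plugins=DeepSpeedPlugin())
    trainer.fit(BoringModel(), datamodule=PlDataModule())  # AssertionError expected
```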
Expected behavior
`train_dataloader` is never called before `setup`.
Additional context
Backtrace