
rework dataloader reset logic in Trainer #8435

Closed
awaelchli opened this issue Jul 15, 2021 · 2 comments

awaelchli (Contributor) commented Jul 15, 2021

🐛 Bug

The reset_{train,val,test}_dataloader methods in Trainer do not work as intended and lead to silent errors and side effects.

The problematic line of code is here: https://github.com/PyTorchLightning/pytorch-lightning/blob/176df202e4e1e5f5101914929f0f3a3608c41f94/pytorch_lightning/trainer/data_loading.py#L447
where a None check prevents the new dataloader from being attached.
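
For illustration, a minimal sketch of the kind of guard described above (simplified, with hypothetical helper names; not the actual Lightning source):

def reset_train_dataloader(trainer, model):
    # Hypothetical simplification of the reset logic: a loader is only attached
    # when none is attached yet, so a second fit() call with a new dataloader
    # silently keeps the stale one from the first fit().
    if trainer.train_dataloader is None:
        trainer.train_dataloader = request_new_train_dataloader(model)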

Please reproduce using the BoringModel

import torch
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import LightningModule, Trainer


class RandomDataset(Dataset):

    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)


def run():
    train_data_0 = DataLoader(RandomDataset(32, 128), batch_size=2)
    train_data_1 = DataLoader(RandomDataset(32, 64), batch_size=2)

    model = BoringModel()
    trainer = Trainer(max_epochs=1, weights_summary=None)
    trainer.fit(model, train_dataloaders=train_data_0)
    assert trainer.train_dataloader.loaders is train_data_0
    # trainer.train_dataloader = None  # manual workaround: detaching the loader forces the second fit() to reattach

    trainer.fit_loop.max_epochs = 2
    # here, fit() does not reset the dataloader, the old one is still attached
    trainer.fit(model, train_dataloaders=train_data_1)
    
    # this assertion fails
    assert trainer.train_dataloader.loaders is train_data_1


if __name__ == '__main__':
    run()

Expected behavior

The assertion does not fail; the second fit() correctly attaches the new dataloader.

Additional context

Reported here by user @sid-sundrani

Related to #6030

@awaelchli awaelchli added bug Something isn't working help wanted Open to be worked on labels Jul 15, 2021
@awaelchli awaelchli self-assigned this Jul 15, 2021
@awaelchli awaelchli added the priority: 0 High priority task label Jul 15, 2021
@awaelchli awaelchli added this to the v1.4.x milestone Jul 15, 2021
awaelchli (Contributor, Author) commented Jul 16, 2021

Looks like the bug was introduced in #7207 while trying to fix something else.

A simple fix could be to detach all loaders from the trainer (by setting them to None) when fit() etc. ends.
cc @ananthsub
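
A rough sketch of that idea (hypothetical names, not an actual patch): at the end of fit()/validate()/test(), clear the attached loaders so the next entry point re-runs the reset logic.

def _detach_dataloaders(trainer):
    # Assumed teardown step: clearing these references makes the None check
    # in the reset logic pass again on the next fit()/validate()/test() call.
    trainer.train_dataloader = None
    trainer.val_dataloaders = None
    trainer.test_dataloaders = None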

@awaelchli awaelchli removed the priority: 0 High priority task label Jul 21, 2021
@tchaton tchaton added the let's do it! approved to implement label Sep 10, 2021
@awaelchli awaelchli modified the milestones: v1.4.x, 1.5.x Nov 3, 2021
carmocca (Contributor) commented Mar 1, 2022

Closing, this is working in master.

@carmocca carmocca closed this as completed Mar 1, 2022