🐛 Bug
The `reset_{train,val,test}_dataloader` methods in `Trainer` do not work as intended, which leads to silent errors and side effects.
The problematic code is here: https://github.com/PyTorchLightning/pytorch-lightning/blob/176df202e4e1e5f5101914929f0f3a3608c41f94/pytorch_lightning/trainer/data_loading.py#L447
A `None` check there skips the reset whenever a dataloader is already attached, so the new dataloader passed to a subsequent `fit()` call is never attached.
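The failure mode can be illustrated without Lightning. The sketch below uses a hypothetical `MiniTrainer` class (not Lightning's API) whose hypothetical `_reset_train_dataloader` method only attaches the loader passed to `fit()` when no loader is attached yet, mirroring the `None` check described above. Plain lists stand in for `DataLoader` objects.

```python
from typing import Iterable, Optional


class MiniTrainer:
    """Hypothetical stand-in for a trainer; not the Lightning API."""

    def __init__(self) -> None:
        self.train_dataloader: Optional[Iterable] = None

    def _reset_train_dataloader(self, loader: Iterable) -> None:
        # Buggy guard: only attach a loader if none is attached yet.
        if self.train_dataloader is None:
            self.train_dataloader = loader

    def fit(self, loader: Iterable) -> int:
        self._reset_train_dataloader(loader)
        # "Train" by counting the batches actually iterated over.
        return sum(1 for _ in self.train_dataloader)


loader_0 = [0, 1, 2, 3]  # stands in for train_data_0
loader_1 = [0, 1]        # stands in for train_data_1

trainer = MiniTrainer()
trainer.fit(loader_0)
batches = trainer.fit(loader_1)

# The second fit() silently reuses loader_0 instead of loader_1.
print(trainer.train_dataloader is loader_0, batches)  # True 4
```

No error is raised: the second call simply trains on the wrong data, which is why the bug is silent.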
Please reproduce using the BoringModel:

```python
import torch
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import LightningModule, Trainer


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)


def run():
    train_data_0 = DataLoader(RandomDataset(32, 128), batch_size=2)
    train_data_1 = DataLoader(RandomDataset(32, 64), batch_size=2)

    model = BoringModel()
    trainer = Trainer(max_epochs=1, weights_summary=None)
    trainer.fit(model, train_dataloaders=train_data_0)
    assert trainer.train_dataloader.loaders is train_data_0

    # trainer.train_dataloader = None
    trainer.fit_loop.max_epochs = 2

    # here, fit() does not reset the dataloader, the old one is still attached
    trainer.fit(model, train_dataloaders=train_data_1)

    # this assertion fails
    assert trainer.train_dataloader.loaders is train_data_1


if __name__ == '__main__':
    run()
```
Expected behavior
The assertion does not fail: the second `fit()` call correctly attaches the new dataloader.
Additional context
Reported here by user @sid-sundrani
Related to #6030
Looks like the bug was introduced in #7207 while trying to fix something else.
A simple fix could be to detach all loaders from the trainer (by setting them to `None`) when `fit()` and similar calls finish. cc @ananthsub
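A minimal sketch of that suggestion, using a hypothetical `MiniTrainer` class rather than Lightning's actual code: the guard that only attaches a loader when none is present is left unchanged, but the attached loader is cleared in a `finally` block when `fit()` returns, so the next call always attaches the loader it was given. Plain lists stand in for `DataLoader` objects.

```python
from typing import Iterable, Optional


class MiniTrainer:
    """Hypothetical stand-in for a trainer; not the Lightning API."""

    def __init__(self) -> None:
        self.train_dataloader: Optional[Iterable] = None

    def _reset_train_dataloader(self, loader: Iterable) -> None:
        # Unchanged guard: attach only when nothing is attached yet.
        if self.train_dataloader is None:
            self.train_dataloader = loader

    def fit(self, loader: Iterable) -> int:
        self._reset_train_dataloader(loader)
        try:
            # "Train" by counting the batches actually iterated over.
            return sum(1 for _ in self.train_dataloader)
        finally:
            # Proposed fix: detach the loader once fit() ends,
            # even if training raised an exception.
            self.train_dataloader = None


trainer = MiniTrainer()
first = trainer.fit([0, 1, 2, 3])
second = trainer.fit([0, 1])
print(first, second, trainer.train_dataloader)  # 4 2 None
```

Using `finally` keeps the trainer clean even after a failed run. One side effect in this sketch is that `trainer.train_dataloader` is `None` after `fit()` returns, so a check like the final assertion in the reproduction would have to inspect the loader during training instead.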
Closing, this is working in master.
awaelchli