
Enable saving and loading stateful DataLoaders in Trainer #19361

Merged: 33 commits merged into master from feature/stateful-dataloader on Feb 1, 2024

Conversation

@awaelchli (Contributor) commented on Jan 29, 2024

What does this PR do?

Saves the state_dict of training dataloaders into the checkpoint, and enables that state to be loaded when resuming from a checkpoint. An example of a stateful dataloader is lightning.data.StreamingDataLoader.
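For context, "stateful" here simply means a dataloader that exposes `state_dict()` and `load_state_dict()`. Below is a minimal, hypothetical sketch of such a loader; `CountingDataLoader` is not part of Lightning, and a real implementation such as `StreamingDataLoader` does more work on load (e.g. fast-forwarding its samplers):

```python
from torch.utils.data import DataLoader


class CountingDataLoader(DataLoader):
    """Hypothetical stateful dataloader that records how many batches it has yielded."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._batches_yielded = 0

    def __iter__(self):
        # Count batches as they are produced so progress survives a checkpoint.
        for batch in super().__iter__():
            self._batches_yielded += 1
            yield batch

    def state_dict(self) -> dict:
        # Captured at checkpoint time.
        return {"batches_yielded": self._batches_yielded}

    def load_state_dict(self, state_dict: dict) -> None:
        # Called on resume; a real loader would also skip already-seen samples.
        self._batches_yielded = state_dict["batches_yielded"]
```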

The implementation collects the state of all iterables under the CombinedLoader that follow the stateful interface. The states are collected over the flattened view (CombinedLoader.flattened) and stored in a list, and they are restored in the same manner via the flattened view. This means the number of iterables and the order in which they are given must be exactly the same as when the checkpoint was saved; otherwise loading will fail. A rough sketch of this mechanism follows below.
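As an illustration of that save/restore path (the function names are mine, not the PR's actual internals; it assumes CombinedLoader.flattened returns the individual iterables in a stable order):

```python
from typing import Any, Dict, List


def collect_dataloader_states(combined_loader) -> List[Dict[str, Any]]:
    # Walk the flattened view; non-stateful iterables contribute an empty
    # placeholder so positions stay aligned with the flattened order.
    return [
        dl.state_dict() if hasattr(dl, "state_dict") else {}
        for dl in combined_loader.flattened
    ]


def restore_dataloader_states(combined_loader, states: List[Dict[str, Any]]) -> None:
    # Restoring assumes the same number and order of iterables as at save
    # time; this is why a mismatch makes loading fail.
    if len(states) != len(combined_loader.flattened):
        raise RuntimeError("The number of dataloaders changed since the checkpoint was saved")
    for dl, state in zip(combined_loader.flattened, states):
        if state and hasattr(dl, "load_state_dict"):
            dl.load_state_dict(state)
```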

Fixes #17105
Closes #17543

cc @Borda @justusschock @awaelchli @carmocca

@github-actions bot added the "pl" (Generic label for PyTorch Lightning package) label on Jan 29, 2024
@awaelchli added the "feature" (Is an improvement or enhancement), "data handling" (Generic data-related topic), "trainer", and "fun" (Staff contributions outside working hours) labels and removed the "pl" label on Jan 29, 2024
@awaelchli added this to the 2.2 milestone on Jan 29, 2024
@github-actions bot added the "pl" label on Jan 29, 2024
@awaelchli force-pushed the feature/stateful-dataloader branch from 1d75450 to 5317545 on January 29, 2024 at 02:45
@awaelchli changed the title from "Save state_dict for stateful DataLoaders" to "Enable saving and loading stateful DataLoaders in Trainer" on Jan 30, 2024
codecov bot commented Jan 30, 2024

Codecov Report

Merging #19361 (0f8621a) into master (5d178d0) will decrease coverage by 35%.
The diff coverage is 100%.

Additional details and impacted files
@@            Coverage Diff             @@
##           master   #19361      +/-   ##
==========================================
- Coverage      84%      49%     -35%     
==========================================
  Files         448      440       -8     
  Lines       37887    37762     -125     
==========================================
- Hits        31649    18392   -13257     
- Misses       6238    19370   +13132     

@tchaton (Contributor) left a comment:

Nice ;)

@mergify bot added the "ready" (PRs ready to be merged) label on Jan 30, 2024
Two review threads on src/lightning/pytorch/loops/fit_loop.py (outdated, resolved)
@awaelchli requested a review from @carmocca on February 1, 2024 at 00:28
@mergify bot added the "has conflicts" label and removed the "ready" (PRs ready to be merged) label on Feb 1, 2024
@mergify bot added the "ready" label and removed the "has conflicts" label on Feb 1, 2024
@awaelchli merged commit 34a34a0 into master on Feb 1, 2024
96 of 97 checks passed
@awaelchli deleted the feature/stateful-dataloader branch on February 1, 2024 at 02:11
Labels
"data handling" (Generic data-related topic), "fault tolerance", "feature" (Is an improvement or enhancement), "fun" (Staff contributions outside working hours, to differentiate from the "community" label), "pl" (Generic label for PyTorch Lightning package), "ready" (PRs ready to be merged), "trainer"
Development

Successfully merging this pull request may close these issues:

- Resuming training gives different model result / weights
- Support manual dataloader fault-tolerance
4 participants