
add _is_fresh_start_epoch attribute to fit_loop to avoid reload dl twice when training starts #9614

Conversation

@ninginthecloud ninginthecloud commented Sep 20, 2021

What does this PR do?

Fixes #9502

This PR aims to fix the bug mentioned in issue #9502.
During the fit loop, the train dataloader is loaded twice when resuming from a checkpoint. This behavior does not meet users' expectations.
There are two options to fix this issue:

  • Option 1 (implemented in this PR): introduce an attribute _is_fresh_start_epoch in fit_loop to track whether a new fit_loop run has just started (a minimal sketch is included after the summary below). _is_fresh_start_epoch is True when
  1. Trainer.fit(model) is called and current_epoch = 0
  2. Trainer.fit(model, dm) is called and current_epoch != 0 but a new dataloader/datamodule is provided
  3. trainer = Trainer(resume_from_checkpoint=ckpt) and Trainer.fit() is called; current_epoch != 0 but Trainer.fit() resumes from the checkpoint
  • Option 2: remove self.reset_train_val_dataloaders(model) from Trainer._run_train() and move the reload logic into the corresponding loops. While working on this option, I found a lot of other underlying assumptions:
  1. self.num_training_batches is set to 0 in trainer._setup_on_init(). This value should be initialized as self.num_training_batches = float('inf'); otherwise, if we remove the dataloader reloading from Trainer._run_train() before calling fit_loop.run(), the skip flag of fit_loop will be True, which makes fit_loop stop prematurely (see the sketch after this list).
  2. We allow users to call trainer.fit() multiple times with different dataloaders. Removing the dataloader reloading from Trainer._run_train() would leave fit_loop unable to continue fitting with new data. In this sense, we still need the _is_fresh_start_epoch attribute mentioned in Option 1.
  3. reset_val_dataloader() is called before fit_loop officially starts, and it is sometimes called when sanity checking is enabled. Naively removing self.reset_train_val_dataloaders(model) from Trainer._run_train() breaks a lot of tests, since our usage of and assumptions about loading val_dataloader() change. As mentioned in issue Move reload_dataloaders_every_n_epochs to the DataHooks class #8738, it is better to provide more granular control so the train and val dataloaders can be reloaded separately.
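To make point 1 of Option 2 concrete, here is a minimal, self-contained sketch of how a skip flag tied to num_training_batches behaves; the class below is a simplified stand-in, not the actual Lightning FitLoop:

```python
# Simplified, hypothetical sketch (not the actual Lightning FitLoop).
# It shows why a num_training_batches default of 0 makes the skip check
# fire before any dataloader has been loaded by the loop itself.

class MiniFitLoop:
    def __init__(self, num_training_batches: float):
        self.num_training_batches = num_training_batches

    @property
    def skip(self) -> bool:
        # The loop is skipped when there are no batches to train on.
        return self.num_training_batches == 0

    def run(self) -> str:
        if self.skip:
            return "skipped"
        return "trained"


# With a default of 0, the loop stops prematurely if the trainer-level
# reload is removed and no batches have been counted yet:
assert MiniFitLoop(num_training_batches=0).run() == "skipped"

# Initializing to float('inf') keeps the loop running until it reloads
# the dataloaders itself and computes the real batch count:
assert MiniFitLoop(num_training_batches=float("inf")).run() == "trained"
```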

In summary, this PR adopts option 1 to quickly fix this bug. As a next step, let's complete issue #8738 and then refactor the dataloader reloading.
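Since this PR takes Option 1, here is a minimal, self-contained sketch of the idea behind _is_fresh_start_epoch. The class and method names are illustrative stand-ins, not the actual Lightning implementation:

```python
# Hypothetical, simplified sketch of Option 1 (not the actual Lightning code).
# The loop remembers whether the current fit() call has just started, so the
# per-epoch reload is not triggered again right after the trainer-level reload
# that already happened before the loop ran.

class MiniLoop:
    def __init__(self, reload_every_n_epochs: int = 1):
        self.reload_every_n_epochs = reload_every_n_epochs
        self.current_epoch = 0
        # True whenever a new fit() call begins (fresh start, new dataloaders,
        # or resume from checkpoint); cleared after the first epoch runs.
        self._is_fresh_start_epoch = True

    def _should_reload_train_dl(self) -> bool:
        # Skip the reload on a fresh start: the dataloaders were already
        # loaded once before entering the loop.
        if self._is_fresh_start_epoch:
            return False
        n = self.reload_every_n_epochs
        return n > 0 and self.current_epoch % n == 0

    def run_epoch(self) -> None:
        if self._should_reload_train_dl():
            print(f"reloading train dataloader at epoch {self.current_epoch}")
        # ... run the training epoch ...
        self._is_fresh_start_epoch = False
        self.current_epoch += 1
```

When resuming from a checkpoint, the restored current_epoch is non-zero, but the flag still suppresses the first reload, so the train dataloader is loaded only once when training starts.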

Does your PR introduce any breaking changes? If yes, please list them.

Before submitting

  • Was this discussed/approved via a GitHub issue? (not for typos and docs)
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure your PR does only one thing, instead of bundling different changes together?
  • Did you make sure to update the documentation with your changes? (if necessary)
  • Did you write any new necessary tests? (not for typos and docs)
  • Did you verify new and existing tests pass locally with your changes?
  • Did you list all the breaking changes introduced by this pull request?
  • Did you update the CHANGELOG? (not for typos, docs, test updates, or internal minor changes/refactorings)

PR review

Anyone in the community is welcome to review the PR.
Before you start reviewing make sure you have read Review guidelines. In short, see the following bullet-list:

  • Is this pull request ready for review? (if not, please submit in draft mode)
  • Check that all items from Before submitting are resolved
  • Make sure the title is self-explanatory and the description concisely explains the PR
  • Add labels and milestones (and optionally projects) to the PR so it can be classified

Did you have fun?

Make sure you had fun coding 🙃

@@ -1164,6 +1164,12 @@ def _run_train(self) -> None:

self.reset_train_val_dataloaders(model)

# a fresh start epoch can occur in the following situations:
Contributor
Hey @ninginthecloud.

Quick idea to investigate: what if we removed reset_train_val_dataloaders from the Trainer (#1165) and let only the loop handle the reloading logic? This is a blocker for loop customization.

Best,
T.C

Contributor Author

Thanks for the suggestion, @tchaton! That's what I've mentioned in the description as option 2 for resolving this fix. We can definitely go that route, and it will change a lot of existing underlying assumptions and behaviors.

Option 2: remove self.reset_train_val_dataloaders(model) from Trainer._run_train() and move the reload logic into the corresponding loops. While working on this option, I found a lot of other underlying assumptions we need to fix:

  1. self.num_training_batches is set to 0 in trainer._setup_on_init(). This value should be initialized as self.num_training_batches = float('inf'); otherwise, if we remove the dataloader reloading from Trainer._run_train() before calling fit_loop.run(), the skip flag of fit_loop will be True, which makes fit_loop stop prematurely.
  2. We allow users to call trainer.fit() multiple times with different dataloaders. Removing the dataloader reloading from Trainer._run_train() would leave fit_loop unable to continue fitting with new data. In this sense, we still need the _is_fresh_start_epoch attribute mentioned in Option 1.
  3. reset_val_dataloader() is called before fit_loop officially starts, and it is sometimes called when sanity checking is enabled. Naively removing self.reset_train_val_dataloaders(model) from Trainer._run_train() breaks a lot of tests, since our usage of and assumptions about loading val_dataloader() change. As mentioned in issue Move reload_dataloaders_every_n_epochs to the DataHooks class #8738, it is better to provide more granular control so the train and val dataloaders can be reloaded separately.
    In summary, this PR adopts option 1 to quickly fix this bug. As a next step, let's complete issue Move reload_dataloaders_every_n_epochs to the DataHooks class #8738 and then refactor the dataloader reloading.

We agree that ultimately we want to provide loop customization; after all, it does not make sense to let the trainer control data reloading. Let me know what you think. I can either 1) postpone issue #8738 and complete option 2, or 2) adopt option 1, complete #8738, and then refactor.

Contributor

IMO if #9502 isn't blocking something important, we can go the longer route for option 2 and wait this out.

Contributor Author

Great~ Thanks @carmocca and @tchaton~ Let me update this PR soon.

@ninginthecloud
Contributor Author

I will close this PR. Please move all discussion to the latest PR #9671, which implements option 2.

Development

Successfully merging this pull request may close these issues.

Dataloader is reloaded twice after resuming from checkpoint