add `_is_fresh_start_epoch` attribute to fit_loop to avoid reloading dataloaders twice when training starts #9614
Conversation
```
@@ -1164,6 +1164,12 @@ def _run_train(self) -> None:

        self.reset_train_val_dataloaders(model)

        # fresh start epoch can be under following situations:
```
Hey @ninginthecloud.
Quick idea to investigate: what if we remove `reset_train_val_dataloaders` from the Trainer (#1165) and let only the loop handle the reloading logic? This is a blocker for loop customization.
Best,
T.C
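To make the suggestion concrete, here is a minimal, hypothetical sketch (the class and method names are illustrative stand-ins, not the actual Lightning internals) of what it could look like for the loop, rather than `Trainer._run_train()`, to own dataloader reloading:

```python
# Illustrative sketch only -- not the real Trainer/FitLoop code. It shows the
# shape of the idea: the loop loads the dataloaders itself and decides when to
# reload them, so Trainer._run_train() no longer calls
# reset_train_val_dataloaders().


class TrainerStub:
    """Stand-in for the small part of the Trainer surface the sketch needs."""

    def __init__(self, reload_dataloaders_every_n_epochs: int = 0) -> None:
        self.reload_dataloaders_every_n_epochs = reload_dataloaders_every_n_epochs
        self.train_dl_loads = 0

    def reset_train_dataloader(self) -> None:
        self.train_dl_loads += 1


class FitLoopSketch:
    def __init__(self, trainer: TrainerStub) -> None:
        self.trainer = trainer
        self.current_epoch = 0

    def on_run_start(self) -> None:
        # The loop performs the initial load exactly once when fit starts.
        self.trainer.reset_train_dataloader()

    def on_advance_start(self) -> None:
        # Later epochs reload only when the user asked for periodic reloading.
        n = self.trainer.reload_dataloaders_every_n_epochs
        if n and self.current_epoch > 0 and self.current_epoch % n == 0:
            self.trainer.reset_train_dataloader()


trainer = TrainerStub(reload_dataloaders_every_n_epochs=2)
loop = FitLoopSketch(trainer)
loop.on_run_start()
for epoch in range(4):
    loop.current_epoch = epoch
    loop.on_advance_start()
print(trainer.train_dl_loads)  # 2: one initial load, one reload at epoch 2
```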
Thanks for the suggestion, @tchaton! That's what I mentioned in the description as option 2 for resolving this fix. We can definitely go that route, and it will change a lot of existing underlying assumptions and behaviors.

Option 2: remove `self.reset_train_val_dataloaders(model)` in `Trainer._run_train()` and put the reload logic in the corresponding loops. When I worked on this option, I found a lot of other underlying assumptions we need to fix:

- `self.num_training_batches = 0` in `trainer._setup_on_init()`. This value should be initialized as `self.num_training_batches = float('inf')`; otherwise, if we remove the dataloader reloading in `Trainer._run_train()` before calling `fit_loop.run()`, the `skip` flag from `fit_loop` will be true, which makes `fit_loop` falsely stop (see the sketch after this comment).
- We allow users to call `trainer.fit()` multiple times with different dataloaders. Removing the dataloader reloading in `Trainer._run_train()` leaves `fit_loop` unable to continue fitting with new data. In this sense, we still need the `_is_fresh_start_epoch` attribute mentioned in option 1.
- `reset_val_dataloader()` is called before `fit_loop` officially starts; sometimes it is called when sanity checking is enabled. Naively removing `self.reset_train_val_dataloaders(model)` in `Trainer._run_train()` breaks a lot of tests, since our usage and assumptions around loading `val_dataloader()` change. As mentioned in issue "Move reload_dataloaders_every_n_epochs to the DataHooks class" #8738, it is better to provide more granular control over reloading the train and val dataloaders separately.

In summary, this PR adopts option 1 to quickly fix this bug. As the next step, let's complete issue #8738 ("Move reload_dataloaders_every_n_epochs to the DataHooks class") and then refactor dataloader reloading.

We agree that ultimately we want to provide loop customization; after all, it does not make sense to let the trainer control data reloading. Let me know what you think: I can either 1) postpone issue #8738 to complete option 2 first, or 2) adopt option 1, complete #8738, and then refactor.
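As an aside, the first bullet above can be illustrated with a small self-contained sketch (hypothetical names, not the actual `FitLoop` implementation): if the loop alone is responsible for loading dataloaders, a `skip`-style check that runs before the first load sees the default `num_training_batches = 0` and wrongly concludes there is nothing to train on.

```python
# Minimal sketch, not the actual Lightning FitLoop: it only demonstrates why a
# default of num_training_batches = 0 trips a skip-style check before the
# dataloaders have been loaded, while float('inf') ("not yet known") does not.


class TrainerState:
    def __init__(self) -> None:
        # Hypothetical default mirroring the one set in trainer._setup_on_init().
        self.num_training_batches = 0


class FitLoopSketch:
    def __init__(self, trainer: TrainerState) -> None:
        self.trainer = trainer

    @property
    def skip(self) -> bool:
        # With option 2 the dataloaders are loaded inside the loop, so this
        # check can run before num_training_batches is ever updated.
        return self.trainer.num_training_batches == 0


state = TrainerState()
loop = FitLoopSketch(state)
print(loop.skip)  # True -> the fit loop would falsely stop before training

state.num_training_batches = float("inf")  # "unknown", rather than "zero batches"
print(loop.skip)  # False -> the loop proceeds and loads the dataloaders itself
```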
IMO if #9502 isn't blocking something important, we can go the longer route for option 2 and wait this out.
Thanks @ninginthecloud.
I will close this PR. Please move all discussion to the latest PR #9671, which implements option 2.
What does this PR do?
Fixes #9502

This PR aims to fix the bug reported in issue #9502: during the fit loop, `train_dataloader` is loaded twice when resuming from a checkpoint. This behavior does not meet users' expectations.
There are two options to fix this issue:

- Option 1: add an attribute `_is_fresh_start_epoch` to `fit_loop` to track whether a new fit_loop run has just started; the flag is used to avoid reloading the train dataloader a second time when training starts (a minimal sketch is shown below).
- Option 2: remove `self.reset_train_val_dataloaders(model)` in `Trainer._run_train()` and put the reload logic in the corresponding loops. When I worked on this option, I found a lot of other underlying assumptions:
  - `self.num_training_batches = 0` in `trainer._setup_on_init()`. This value should be initialized as `self.num_training_batches = float('inf')`; otherwise, if we remove the dataloader reloading in `Trainer._run_train()` before calling `fit_loop.run()`, the `skip` flag from `fit_loop` will be true, which makes `fit_loop` falsely stop.
  - We allow users to call `trainer.fit()` multiple times with different dataloaders. Removing the dataloader reloading in `Trainer._run_train()` leaves `fit_loop` unable to continue fitting with new data. In this sense, we still need the `_is_fresh_start_epoch` attribute mentioned in option 1.
  - `reset_val_dataloader()` is called before `fit_loop` officially starts; sometimes it is called when sanity checking is enabled. Naively removing `self.reset_train_val_dataloaders(model)` in `Trainer._run_train()` breaks a lot of tests, since our usage and assumptions around loading `val_dataloader()` change. As mentioned in issue "Move reload_dataloaders_every_n_epochs to the DataHooks class" #8738, it is better to provide more granular control over reloading the train and val dataloaders separately.

In summary, this PR adopts option 1 to quickly fix this bug. As the next step, let's complete issue #8738 and then refactor dataloader reloading.
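For reference, here is a minimal sketch of option 1. Apart from `_is_fresh_start_epoch`, the names are simplified stand-ins rather than the exact `Trainer`/`FitLoop` internals: the flag marks the epoch that immediately follows `trainer.fit()`, so the per-epoch reload check is skipped once because `Trainer._run_train()` has already loaded the dataloaders.

```python
# Hypothetical sketch of option 1; method and attribute names other than
# _is_fresh_start_epoch are simplified stand-ins, not the exact internals.


class TrainerStub:
    def __init__(self) -> None:
        self.train_dl_loads = 0

    def reset_train_val_dataloaders(self) -> None:
        self.train_dl_loads += 1

    def reset_train_dataloader(self) -> None:
        self.train_dl_loads += 1


class FitLoopSketch:
    def __init__(self, trainer: TrainerStub) -> None:
        self.trainer = trainer
        # True right after trainer.fit() is called -- a "fresh start" epoch.
        self._is_fresh_start_epoch = True

    def on_advance_start(self, should_reload_dl: bool) -> None:
        # Skip the per-epoch reload on the fresh-start epoch, because
        # Trainer._run_train() has already loaded the dataloaders.
        if should_reload_dl and not self._is_fresh_start_epoch:
            self.trainer.reset_train_dataloader()
        self._is_fresh_start_epoch = False


trainer = TrainerStub()
loop = FitLoopSketch(trainer)

# What Trainer._run_train() does before running the loop:
trainer.reset_train_val_dataloaders()

# First epoch after resuming: the reload condition may be true, but the flag
# prevents a second, back-to-back reload.
loop.on_advance_start(should_reload_dl=True)
print(trainer.train_dl_loads)  # 1, not 2

# A later epoch where reloading is due proceeds normally.
loop.on_advance_start(should_reload_dl=True)
print(trainer.train_dl_loads)  # 2
```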
Does your PR introduce any breaking changes? If yes, please list them.
Before submitting
PR review
Anyone in the community is welcome to review the PR.
Before you start reviewing, make sure you have read the Review guidelines.
Did you have fun?
Make sure you had fun coding 🙃