
inconsistent step counting for train and val after resuming from checkpoint #5547

Closed
AsaphLightricks opened this issue Jan 17, 2021 · 11 comments
Labels
logger (Related to the Loggers), won't fix (This will not be worked on)

Comments

@AsaphLightricks

https://github.com/PyTorchLightning/pytorch-lightning/blob/9ebbfece5e2c56bb5300cfffafb129e399492469/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py#L187-L193

I have an issue where, if I resume a run from a checkpoint, the step being passed to my WandbLogger is inconsistent.

After resuming, the step count in training steps is set to batch_idx, which starts from 0 even though training was resumed.
However, for validation steps the step count is set to self.trainer.global_step, which continues from the global_step recovered from the checkpoint.
Why isn't the step count during training set to self.trainer.global_step as well?

This discrepancy is causing my WandbLogger to drop logs. During the first epoch of a resumed run, the step counts being passed to the WandbLogger are 0, 1, 2, ... (because they are the batch_idx), but during validation they suddenly jump to self.trainer.global_step, which is much higher because the run was resumed. Then, as the second epoch starts, the step count goes back to batch_idx, which is lower than self.trainer.global_step. The WandbLogger then sees that the internal step is not monotonically increasing and drops the logs.
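
For illustration, here is a minimal sketch outside of Lightning (the project name and loss values are made up) of the kind of sequence the WandbLogger ends up sending after a resume, and why the later training points get dropped:

import wandb

run = wandb.init(project="step-demo")  # hypothetical project name

# After resuming, training logs start again from batch_idx = 0 ...
run.log({"train/loss": 0.50}, step=0)
run.log({"train/loss": 0.48}, step=1)

# ... while validation logs use the restored global_step, which is much larger.
run.log({"val/loss": 0.45}, step=1000)

# The next training batch goes back to a small batch_idx; since the step is no
# longer monotonically increasing, wandb warns and drops this point.
run.log({"train/loss": 0.47}, step=2)

run.finish()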

@github-actions
Contributor

Hi! Thanks for your contribution, great first issue!

@awaelchli
Contributor

Was it solved here? #5050

awaelchli added the information needed and logger labels on Jan 18, 2021
@awaelchli
Contributor

Also, perhaps at the same time you may need this feature #5351 to avoid dropping logs during validation :)

@AsaphLightricks
Author

Hi @awaelchli,

Unfortunately, no and no.
The first PR didn't solve it (the code I'm using now is updated to include it).
The second PR does provide a 'hacky' workaround: it prevents the dropping of logs, and wandb supports a custom x-axis for its graphs, so I can force the step to be the global step.
However, the underlying issue is still there: what's the thinking behind returning global_step in validation_step() but batch_idx in train_step()? Is this an oversight? If not, what's the purpose of this decision?
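
For reference, the custom x-axis workaround I mean looks roughly like this (metric names are illustrative; define_metric ties each metric family to a logged step metric instead of wandb's internal step counter):

import wandb

run = wandb.init(project="step-demo")  # hypothetical project name

# Use a logged metric as the x-axis instead of wandb's internal step counter,
# so out-of-order internal steps no longer hide points on the charts.
run.define_metric("trainer/global_step")
run.define_metric("train/*", step_metric="trainer/global_step")
run.define_metric("val/*", step_metric="trainer/global_step")

# Each log call then carries the true global step explicitly.
run.log({"train/loss": 0.48, "trainer/global_step": 1001})
run.log({"val/loss": 0.45, "trainer/global_step": 1001})

run.finish()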

@borisdayma
Contributor

borisdayma commented Jan 20, 2021

Right, the first PR makes sure we start logging after the last logged step of the resumed run, but it assumed we would start logging at 0 when using a new Trainer (even when reloading a model).

However, the underlying issue is still there: what's the thinking behind returning global_step in validation_step() but batch_idx in train_step()? Is this an oversight? If not, what's the purpose of this decision?

@AsaphLightricks I'm curious if you were able to test this change to see if it would work in your case?

@AsaphLightricks
Author

@borisdayma, I just tested it, and it solves the problem. I commented out global_step and replaced it with total_batch_idx:

elif step is None:
    # added metrics by Lightning for convenience
    if log_train_step_metrics:
        step = self.trainer.total_batch_idx
    else:
        scalar_metrics['epoch'] = self.trainer.current_epoch
        # step = self.trainer.global_step
        step = self.trainer.total_batch_idx

With this change, I was able to resume a run from a checkpoint, and the wandb logger works just fine (it does not drop logs).

@awaelchli
Contributor

awaelchli commented Jan 21, 2021

I cannot say with 100% certainty (I will test later), but if I remember correctly, global_step is resumed (so it does not restart at 0) while total_batch_idx is reset. It is also not the same as the batch_idx in training_step.
Also, total_batch_idx will be larger than global_step when using accumulated gradients.
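
To make the gradient accumulation point concrete, here is a rough illustration of the bookkeeping (not Lightning internals), assuming accumulate_grad_batches = 4:

# Rough illustration only: with gradient accumulation, the optimizer steps once
# every `accumulate_grad_batches` batches, so the two counters diverge.
accumulate_grad_batches = 4
batches_seen = 100  # batches processed so far

total_batch_idx = batches_seen                          # 100
global_step = batches_seen // accumulate_grad_batches   # 25

# Charts keyed on total_batch_idx vs. global_step would therefore use different
# x-axes, so the two counters cannot simply be swapped in the logger connector.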

@borisdayma
Contributor

Hi, I'm just checking whether someone from the PL team knows why the step is different, or whether we can just make the change proposed by Asaph.

@stale

stale bot commented Feb 27, 2021

This issue has been automatically marked as stale because it hasn't had any recent activity. This issue will be closed in 7 days if no further activity occurs. Thank you for your contributions, Pytorch Lightning Team!

stale bot added the won't fix label on Feb 27, 2021
stale bot closed this as completed on Mar 7, 2021
awaelchli reopened this on Mar 7, 2021
stale bot removed the won't fix label on Mar 7, 2021
@awaelchli
Contributor

No, we cannot implement the change proposed by @AsaphLightricks.
In general, total_batch_idx is different from global_step, and we (normally) log on global steps.
This looks like a matter of resuming wandb with the correct initial step. And the issue with the step not monotonically increasing was solved in #5931 by @borisdayma, right?
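
As a sketch of what I mean by resuming wandb with the correct initial step (the run id and checkpoint path are placeholders; WandbLogger forwards extra keyword arguments such as resume to wandb.init):

from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger

# Reuse the id of the original W&B run so it resumes with its step counter
# intact, instead of starting a new run whose internal step begins at 0.
wandb_logger = WandbLogger(
    project="my-project",   # placeholder project name
    id="original-run-id",   # id recorded from the first run (placeholder)
    resume="allow",         # forwarded to wandb.init
)

trainer = Trainer(
    logger=wandb_logger,
    resume_from_checkpoint="path/to/last.ckpt",  # placeholder checkpoint path
)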

@stale

stale bot commented Apr 6, 2021

This issue has been automatically marked as stale because it hasn't had any recent activity. This issue will be closed in 7 days if no further activity occurs. Thank you for your contributions, Pytorch Lightning Team!

stale bot added the won't fix label on Apr 6, 2021
stale bot closed this as completed on Apr 23, 2021