Save checkpoint and validate every n steps #2534
Set |
Won’t that do validation every n epochs? |
Oh, you are right, I misunderstood your question. |
As per my understanding, you want to validate the model after every n steps in the same epoch. If I am correct then |
Not really. |
My current workaround is to add such a callback: |
Ok, your n is |
Let's make this issue into a feature request, allowing |
This is an important feature. Especially for large datasets where an epoch may take a whole day, we might want to save a checkpoint in between epochs in case something goes wrong. We need a way to checkpoint based on steps, or in between epochs. |
Saving a checkpoint every N steps should really not be tied to validation. For some models it doesn't make sense to monitor a decreasing validation loss. For example: vanilla GANs expect a constantly shifting loss value between generator and discriminator. We need independent N-steps checkpointing. |
Also for the opposite; if you have very short epochs, you don't wanna spend time/disk-space saving so many checkpoints. |
Hi all, I believe I figured out how to save every N steps, independent of validation metrics. All you need to do is create a Callback that overrides `on_batch_end`:

```python
import os

import pytorch_lightning as pl


class CheckpointEveryNSteps(pl.Callback):
    """
    Save a checkpoint every N steps, instead of Lightning's default that checkpoints
    based on validation loss.
    """

    def __init__(
        self,
        save_step_frequency,
        prefix="N-Step-Checkpoint",
        use_modelcheckpoint_filename=False,
    ):
        """
        Args:
            save_step_frequency: how often to save, in steps
            prefix: add a prefix to the name, only used if
                use_modelcheckpoint_filename=False
            use_modelcheckpoint_filename: just use the ModelCheckpoint callback's
                default filename, don't use ours.
        """
        self.save_step_frequency = save_step_frequency
        self.prefix = prefix
        self.use_modelcheckpoint_filename = use_modelcheckpoint_filename

    def on_batch_end(self, trainer: pl.Trainer, _):
        """Check if we should save a checkpoint after every train batch."""
        epoch = trainer.current_epoch
        global_step = trainer.global_step
        if global_step % self.save_step_frequency == 0:
            if self.use_modelcheckpoint_filename:
                filename = trainer.checkpoint_callback.filename
            else:
                filename = f"{self.prefix}_{epoch=}_{global_step=}.ckpt"
            ckpt_path = os.path.join(trainer.checkpoint_callback.dirpath, filename)
            trainer.save_checkpoint(ckpt_path)
```

Then use it like `Trainer(callbacks=[CheckpointEveryNSteps(save_step_frequency=1000)])`. I realize this answers a slightly different question than the original Issue asked for (this doesn't validate), but I'll leave it here because N-Step checkpointing is a common use case. |
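The gating logic in the callback above is just a modulo check on the global step. As a standalone sketch (pure Python, no Lightning required; `checkpoint_steps` is a hypothetical helper name, not part of the callback):

```python
def checkpoint_steps(total_steps, save_step_frequency):
    """Return the global steps at which the callback above would save.

    Mirrors the `global_step % save_step_frequency == 0` check; note it
    also fires at step 0, so an initial checkpoint gets written.
    """
    return [step for step in range(total_steps) if step % save_step_frequency == 0]


print(checkpoint_steps(10, 3))  # -> [0, 3, 6, 9]
```

If the initial save at step 0 is unwanted, an extra `global_step != 0` condition in the callback would suppress it.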
@andrewjong, that's really helpful. |
and there seems to be
This commit mentions every_n_train_steps, but that results in
I am still forced to use 1.2.5 because of this, since the only solution is to use
Edit: Pretty annoying that I am not able to use pytorch lightning directly with pip anymore. I will probably fork it and add these modifications.
Edit2: but it works with 1.4.9. My modifications are in
Here is my
Edit3: My train class is here. The easier solution is just to downgrade to 1.2.5 and use the
I also noticed that the function for saving models / running eval when a keyboard interrupt is raised also changed. In 1.2.5 |
I think it would be very useful if Trainer had a parameter similar to check_val_every_n_steps for the case of data sets whose training time per epoch is very long. The CheckpointEveryNSteps callback takes care of saving the model but doesn't run |
@zplizzi you could just comment here... 🐰 |
You can already achieve the desired behavior requested in #11468 by using |
It's not the same :(. ModelCheckpoint in the end is linked to validation_step, validation_epoch_end and its log process. val_check_interval only allows you to use a float when < 1, not a specific number of training steps. When val_check_interval > 1 the behaviour is not correct for a non-iterable dataset :c. |
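For context on the float case: a fractional val_check_interval is interpreted relative to the length of the training epoch, roughly as sketched below. This is a simplified illustration, not Lightning's actual internals, and `val_batch_interval` is a hypothetical name:

```python
def val_batch_interval(val_check_interval, num_training_batches):
    """Rough sketch of how val_check_interval maps to a within-epoch batch
    interval. Because the count is relative to the epoch, it resets every
    epoch, which is why it cannot express "every N global steps"."""
    if isinstance(val_check_interval, float):
        if not 0.0 < val_check_interval <= 1.0:
            raise ValueError("float val_check_interval must be in (0.0, 1.0]")
        return max(1, int(num_training_batches * val_check_interval))
    # int case: validate every this many training batches within the epoch
    return val_check_interval


print(val_batch_interval(0.25, 1000))  # -> 250
```

This also shows why an int larger than the epoch length is problematic: the within-epoch batch counter never reaches it.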
I modified the previously provided workaround to make it work with pytorch-lightning version 1.5.4 (I did not test other versions). The following works for me:
Then you need to comment out the following lines in the function
This workaround allows you to run the validation procedure every n (global) steps, i.e. without restarting the counter at every epoch. Additionally, if you also want to checkpoint, you can easily add a ModelCheckpoint callback and set
To make the progress bar work properly for validation, you can extend an existing progress bar and override the following property: |
Just to add to the workaround development going on here, I've built on this to get model checkpointing, validation, and early stopping all based on the global number of steps. Works on 1.5.4 as well, but I wanted to remove the need for forking the Lightning source.

Slightly modified Callback for global-step validation:

```python
import pytorch_lightning as pl
from pytorch_lightning.trainer.states import RunningStage


class ValEveryNSteps(pl.Callback):
    def __init__(self, every_n_steps):
        self.last_run = None
        self.every_n_steps = every_n_steps

    def on_batch_end(self, trainer, pl_module):
        # Prevent running validation many times under gradient accumulation
        if trainer.global_step == self.last_run:
            return
        else:
            self.last_run = None
        if trainer.global_step % self.every_n_steps == 0 and trainer.global_step != 0:
            trainer.training = False
            stage = trainer.state.stage
            trainer.state.stage = RunningStage.VALIDATING
            trainer._run_evaluate()
            trainer.state.stage = stage
            trainer.training = True
            trainer.logger_connector._epoch_end_reached = False
            self.last_run = trainer.global_step
```

Trainer args to set everything based on validation:

```python
max_steps = ...  # max steps here
n_steps = ...  # number of steps per val/checkpoint/stopping check

trainer = Trainer(
    ...
    max_steps=max_steps,
    max_epochs=1000,
    check_val_every_n_epoch=1000,
    callbacks=[
        ValEveryNSteps(n_steps),
        pl.callbacks.ModelCheckpoint(
            monitor="val_loss", save_on_train_epoch_end=False
        ),
        pl.callbacks.EarlyStopping(
            monitor="val_loss",
            min_delta=0.00,
            patience=5,
            verbose=True,
            mode="min",
            check_on_train_epoch_end=False,
        ),
    ],
) |
```
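The `last_run` guard in the callback above exists because, with gradient accumulation, `on_batch_end` can fire several times while `global_step` stays the same. A minimal sketch of just that dedup logic (pure Python, no Lightning required; `StepDeduper` is a hypothetical name):

```python
class StepDeduper:
    """Run an action at most once per global step that is a non-zero
    multiple of N, mirroring the last_run guard in the callback above."""

    def __init__(self, every_n_steps):
        self.every_n_steps = every_n_steps
        self.last_run = None

    def should_run(self, global_step):
        if global_step == self.last_run:
            return False  # already ran at this step (accumulation repeat)
        if global_step % self.every_n_steps == 0 and global_step != 0:
            self.last_run = global_step
            return True
        return False


d = StepDeduper(2)
# Simulate gradient accumulation: each global step is observed twice.
fired = [s for s in [0, 0, 1, 1, 2, 2, 3, 3, 4, 4] if d.should_run(s)]
print(fired)  # -> [2, 4]
```

Without the guard, validation would run once per accumulated micro-batch instead of once per optimizer step.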
It is pytorch-lightning 1.6.0 now. Does it support validating every n steps now? I really want this for training and validating on a large-scale dataset. I also want to validate for m steps rather than the whole validation dataset, if possible. |
Support has been added with #11993 |
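For reference, a sketch of what the step-based setup might look like on recent versions (assumes pytorch-lightning >= 1.6 and that #11993 is the feature allowing step-based validation; the specific numbers are examples, and the block is guarded so it is inert when the library is absent):

```python
# Hedged sketch, not verified against any particular release.
try:
    import pytorch_lightning as pl

    # Checkpoint purely by training steps, independent of validation metrics.
    checkpoint_cb = pl.callbacks.ModelCheckpoint(
        every_n_train_steps=1000,  # example frequency
        save_top_k=-1,             # keep every step checkpoint
    )
    trainer = pl.Trainer(
        val_check_interval=1000,   # run validation every 1000 training batches
        limit_val_batches=100,     # validate on m batches, not the full set
        callbacks=[checkpoint_cb],
    )
except ImportError:
    pass  # pytorch-lightning not installed
```

This is configuration only; consult the release notes of your installed version before relying on exact argument semantics.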
❓ Questions and Help
How to save checkpoint and validate every n steps.
I saw there is a val_check_interval, but it seems it's not for that purpose.