Save checkpoint and validate every n steps #2534

Closed
ruotianluo opened this issue Jul 7, 2020 · 24 comments · Fixed by #6146
Labels
feature Is an improvement or enhancement help wanted Open to be worked on

Comments

@ruotianluo
Contributor

❓ Questions and Help

How can I save a checkpoint and run validation every n steps?

I saw that there is a val_check_interval, but it doesn't seem to be meant for that.

@ruotianluo ruotianluo added the question Further information is requested label Jul 7, 2020
@nanaHa1003

Setting check_val_every_n_epoch=n in Trainer may do the trick.

@ruotianluo
Contributor Author

Won’t that do validation every n epochs?

@nanaHa1003

Oh, you are right, I misunderstood your question.

@rohitgr7
Contributor

rohitgr7 commented Jul 7, 2020

If I understand correctly, you want to validate the model every n steps within the same epoch. If so, val_check_interval already does that:
https://github.com/PyTorchLightning/pytorch-lightning/blob/7ef73f242ad4e5f14e6c967c8639c3a65285a048/pytorch_lightning/trainer/evaluation_loop.py#L52-L68

@ruotianluo
Contributor Author

Not really.
First, val_check_interval can't be larger than the number of training batches.
Second, it counts steps within one epoch; that is, validation runs when batch_idx % val_check_interval == 0 rather than when global_step % val_check_interval == 0.
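A small plain-Python simulation makes the difference concrete. It involves no Lightning code: `trigger_points` is a hypothetical helper, the epoch and interval values are made up, and counts are taken after each batch finishes (Lightning's exact off-by-one convention may differ). The epoch-local rule restarts every epoch, while the global rule stays evenly spaced across epochs and still fires when the interval exceeds the epoch length.

```python
def trigger_points(epochs, batches_per_epoch, interval):
    """Return (epoch, batch_idx) pairs where each rule would run validation.

    Counts are 1-based (the check runs after a batch finishes): the
    epoch-local rule fires when (batch_idx + 1) % interval == 0, the global
    rule fires when the total number of batches seen is a multiple of interval.
    """
    per_epoch, global_rule = [], []
    seen = 0  # total training batches processed so far, across epochs
    for epoch in range(epochs):
        for batch_idx in range(batches_per_epoch):
            seen += 1
            if (batch_idx + 1) % interval == 0:
                per_epoch.append((epoch, batch_idx))
            if seen % interval == 0:
                global_rule.append((epoch, batch_idx))
    return per_epoch, global_rule


# 3 epochs of 4 batches, validate every 3 batches
per_epoch, global_rule = trigger_points(epochs=3, batches_per_epoch=4, interval=3)
print(per_epoch)    # [(0, 2), (1, 2), (2, 2)]  -> restarts every epoch
print(global_rule)  # [(0, 2), (1, 1), (2, 0), (2, 3)]  -> evenly spaced overall

# Interval longer than an epoch: the epoch-local rule never fires
print(trigger_points(epochs=2, batches_per_epoch=4, interval=6))  # ([], [(1, 1)])
```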

@ruotianluo
Contributor Author

My current workaround is to set check_val_every_n_epoch to float('inf') and add a callback like this:


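import pytorch_lightning as pl

class ValEveryNSteps(pl.Callback):
    def __init__(self, every_n_step):
        self.every_n_step = every_n_step

    def on_batch_end(self, trainer, pl_module):
        # Run validation every `every_n_step` global steps, skipping step 0.
        # trainer.run_evaluation exists in older releases (e.g. 1.2.5) but not in 1.4.9.
        if trainer.global_step % self.every_n_step == 0 and trainer.global_step != 0:
            trainer.run_evaluation(test_mode=False)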

@rohitgr7
Contributor

rohitgr7 commented Jul 7, 2020

OK, so your n counts global steps.

@awaelchli
Contributor

Should we turn this issue into a feature request to allow val_check_interval > len(train_dataloader)?

@awaelchli awaelchli added feature Is an improvement or enhancement help wanted Open to be worked on and removed question Further information is requested labels Jul 8, 2020
@andrewjong

andrewjong commented Aug 11, 2020

This is an important feature. Especially for large datasets where an epoch may take a whole day, we might want to save a checkpoint in between epochs in case something goes wrong. We need a way to checkpoint based on steps, or in between epochs.

@andrewjong

andrewjong commented Aug 12, 2020

Saving a checkpoint every N steps should really not be tied to validation.

For some models it doesn't make sense to monitor a decreasing validation loss. For example, with vanilla GANs the generator and discriminator losses are expected to keep shifting against each other. We need independent N-step checkpointing.

@bryant1410
Contributor

> This is an important feature. Especially for large datasets where an epoch may take a whole day, we might want to save a checkpoint in between epochs in case something goes wrong.

It also matters in the opposite case: if your epochs are very short, you don't want to spend time and disk space saving so many checkpoints.

@andrewjong

andrewjong commented Aug 16, 2020

Hi all, I believe I figured out how to save every N steps, independent of validation metrics. All you need to do is create a Callback that overrides on_batch_end in v0.8.5 (or on_train_batch_end in v0.9+). The code below saves to the same directory as the other checkpoints.

import os
import pytorch_lightning as pl

class CheckpointEveryNSteps(pl.Callback):
    """
    Save a checkpoint every N steps, instead of Lightning's default that checkpoints
    based on validation loss.
    """

    def __init__(
        self,
        save_step_frequency,
        prefix="N-Step-Checkpoint",
        use_modelcheckpoint_filename=False,
    ):
        """
        Args:
            save_step_frequency: how often to save in steps
            prefix: add a prefix to the name, only used if
                use_modelcheckpoint_filename=False
            use_modelcheckpoint_filename: just use the ModelCheckpoint callback's
                default filename, don't use ours.
        """
        self.save_step_frequency = save_step_frequency
        self.prefix = prefix
        self.use_modelcheckpoint_filename = use_modelcheckpoint_filename

    def on_batch_end(self, trainer: pl.Trainer, _):
        """ Check if we should save a checkpoint after every train batch """
        epoch = trainer.current_epoch
        global_step = trainer.global_step
        if global_step % self.save_step_frequency == 0:
            if self.use_modelcheckpoint_filename:
                filename = trainer.checkpoint_callback.filename
            else:
                filename = f"{self.prefix}_{epoch=}_{global_step=}.ckpt"
            ckpt_path = os.path.join(trainer.checkpoint_callback.dirpath, filename)
            trainer.save_checkpoint(ckpt_path)

trainer = pl.Trainer(callbacks=[CheckpointEveryNSteps(save_step_frequency=1000)])  # 1000 is just an example frequency

I realize this answers a slightly different question than the original issue asked (this doesn't validate), but I'll leave it here because N-step checkpointing is a common use case.
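As a quick sanity check of the callback's logic without Lightning, the sketch below mirrors its trigger condition and default filename in plain Python. `checkpoint_name` is a hypothetical helper for illustration only. Two things it makes visible: the `{epoch=}` f-string syntax requires Python 3.8+, and the modulo check as written is also true at `global_step == 0`, so a checkpoint is written at the very first step unless you add a `!= 0` guard like the one in the validation callback earlier in this thread.

```python
def checkpoint_name(prefix, epoch, global_step, save_step_frequency):
    """Mirror CheckpointEveryNSteps: return the filename it would save, or None.

    Hypothetical helper, not part of Lightning. Like the callback, it does not
    skip global_step == 0.
    """
    if global_step % save_step_frequency != 0:
        return None
    # `{epoch=}` renders as "epoch=<value>" (f-string debug syntax, Python 3.8+)
    return f"{prefix}_{epoch=}_{global_step=}.ckpt"


for step in (0, 250, 500):
    print(step, checkpoint_name("N-Step-Checkpoint", 0, step, 500))
# 0 N-Step-Checkpoint_epoch=0_global_step=0.ckpt
# 250 None
# 500 N-Step-Checkpoint_epoch=0_global_step=500.ckpt
```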

@stale

stale bot commented Nov 17, 2020

This issue has been automatically marked as stale because it hasn't had any recent activity. This issue will be closed in 7 days if no further activity occurs. Thank you for your contributions, Pytorch Lightning Team!

@stale stale bot added the won't fix This will not be worked on label Nov 17, 2020
@stale stale bot closed this as completed Nov 24, 2020
@nrjvarshney

nrjvarshney commented Feb 9, 2021

@andrewjong, that's really helpful.
I want to plot the learning curves for the training, validation, and test datasets. It would be great if you could help me out.
Essentially, I don't want to save the model; I want to evaluate it on the val and test datasets every n steps.
I can use Trainer(val_check_interval=0.25) for the validation set, but what about the test set? And is there an easier way to plot the curves directly in TensorBoard?

@styler00dollar

styler00dollar commented Oct 21, 2021

trainer.run_evaluation() does not exist in 1.4.9

AttributeError: 'Trainer' object has no attribute 'run_evaluation'

and there seems to be trainer._run_evaluate(), but that results in

File "/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/trainer.py", line 1074, in _run_evaluate
    assert self.evaluating
AssertionError

This commit mentions every_n_train_steps, but that results in

File "/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/connectors/env_vars_connector.py", line 40, in insert_env_defaults
    return fn(self, **kwargs)
TypeError: __init__() got an unexpected keyword argument 'every_n_train_steps'

Because of this I am still stuck on 1.2.5: the only solution I have is CheckpointEveryNSteps, which works with that version, but DDP seems broken unless I update. Is there any easy way to do iteration-based eval and checkpointing in newer versions?

Edit:
It seems there is only one real solution:

  1. Use CheckpointEveryNSteps from the comment above, but replace trainer.run_evaluation() with trainer._run_evaluate().
  2. In /usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/trainer.py, comment out the line assert self.evaluating inside _run_evaluate().
  3. In /usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py, comment out every assert not self._epoch_end_reached.

It's pretty annoying that I can no longer use PyTorch Lightning directly from pip. I will probably fork it and add these modifications.

Edit2:
I tried applying the same changes to master (1.5), but that results in

  File "/usr/local/lib/python3.7/dist-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py", line 110, in advance
    if not self.trainer.data_connector.evaluation_data_fetcher.store_on_device:
AttributeError: 'NoneType' object has no attribute 'store_on_device'

but it works with 1.4.9. My modifications are in

pip install git+https://github.com/styler00dollar/pytorch-lightning.git@fc86f4ca817d5ba1702a210a898ac2729c870112

Here is my checkpoint.py with my modified CheckpointEveryNSteps if anybody is curious. It does checkpoint/val after a certain iteration amount and also saves the models when a keyboard interrupt is triggered (stopping training).

Edit3:
If you do this, you probably also need to increment the iteration count manually, since the counter seems to get stuck. If it never increments, the callback will never save.

    def __init__(self):
        super().__init__()
        self.iter_check = 0

    def training_step(self, train_batch, batch_idx, optimizer_idx=0):
        # iteration count is sometimes broken, adding a check and manual increment
        # only increment if generator gets trained (loop gets called a second time for discriminator)
        if self.trainer.global_step != 0:
            if optimizer_idx == 0 and self.iter_check == self.trainer.global_step:
                self.trainer.global_step += 1
            self.iter_check = self.trainer.global_step
        # ... rest of training_step (loss computation and return) omitted

My train class is here.

The easier solution is simply to downgrade to 1.2.5 and use CheckpointEveryNSteps with trainer.run_evaluation() on unmodified PyTorch Lightning, with no iteration-counting workaround.

I also noticed that the hook for saving models / running eval on a keyboard interrupt has changed. In 1.2.5 you add def on_train_end(self, trainer, pl_module): to CheckpointEveryNSteps, while in 1.4.9 it is def on_keyboard_interrupt(self, trainer, pl_module):.

@camilomarino

I think it would be very useful if Trainer had a check_val_every_n_steps parameter (analogous to check_val_every_n_epoch) for datasets where a single training epoch takes a very long time.

The CheckpointEveryNSteps callback takes care of saving the model but doesn't run validation_step. Is there a way to run the validation_step method every n steps?

@Borda
Member

Borda commented Jan 13, 2022

@zplizzi you could just comment here... 🐰

@Borda Borda reopened this Jan 13, 2022
@stale stale bot removed the won't fix This will not be worked on label Jan 13, 2022
@EricWiener
Contributor

You can already achieve the behavior requested in #11468 by using ModelCheckpoint with every_n_train_steps specified and setting val_check_interval in the Trainer.

@jumpynitro

It's not the same :( ModelCheckpoint is ultimately tied to validation_step, validation_epoch_end and their logging. val_check_interval only lets you use a float < 1, not a specific number of training steps, and when val_check_interval > 1 the behaviour is not correct for a non-iterable dataset :c

@AlessioQuercia
Contributor

AlessioQuercia commented Feb 18, 2022

I modified the previously provided workaround

My current workaround is: set check_val_every_n_epoch to be float('inf').

And add such callback:


class ValEveryNSteps(pl.Callback):
    def __init__(self, every_n_step):
        self.every_n_step = every_n_step

    def on_batch_end(self, trainer, pl_module):
        if trainer.global_step % self.every_n_step == 0 and trainer.global_step != 0:
            trainer.run_evaluation(test_mode=False)

to make it work with pytorch-lightning version 1.5.4 (I did not test other versions). The following works for me:

Set check_val_every_n_epoch to float('inf') and add this callback:

from pytorch_lightning import Callback
from pytorch_lightning.trainer.states import RunningStage

class ValEveryNSteps(Callback):
    def __init__(self, every_n_steps):
        self.every_n_steps = every_n_steps

    def on_train_batch_end(
        self, trainer, pl_module, outputs, batch, batch_idx, unused=0
    ):
        if trainer.global_step % self.every_n_steps == 0 and trainer.global_step != 0:
            # Temporarily switch the trainer into validation mode, run the loop, then switch back
            trainer.training = False
            stage = trainer.state.stage
            trainer.state.stage = RunningStage.VALIDATING
            trainer._run_evaluate()
            trainer.state.stage = stage
            trainer.training = True
            trainer.logger_connector._epoch_end_reached = False

Then comment out the following lines in reset_train_dataloader in pytorch_lightning\trainer\data_loading.py, lines 467-474 (on the PyTorch Lightning master branch, this code is in pytorch_lightning\trainer\trainer.py, lines 1823-1830):

        if isinstance(self.val_check_interval, int):
            self.val_check_batch = self.val_check_interval
            #if self.val_check_batch > self.num_training_batches:
            #    raise ValueError(
            #        f"`val_check_interval` ({self.val_check_interval}) must be less than or equal "
            #        f"to the number of the training batches ({self.num_training_batches}). "
            #        "If you want to disable validation set `limit_val_batches` to 0.0 instead."
            #    )

This workaround allows you to run the validation procedure every n (global) steps, therefore without restarting the counter at every epoch. Additionally, if you also want to checkpoint, you can easily add a ModelCheckpoint callback and set every_n_val_epochs=1.
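One fragile point in the callback above is that `trainer.training` and `trainer.state.stage` are only restored if `_run_evaluate()` returns normally. A more defensive pattern is to swap the attributes in a `try`/`finally`. The sketch below shows that pattern as a generic context manager; `temporarily_set` is a hypothetical helper, and the `SimpleNamespace` objects merely stand in for the trainer and its state, so this has not been tested against Lightning's internals.

```python
from contextlib import contextmanager
from types import SimpleNamespace


@contextmanager
def temporarily_set(obj, **overrides):
    """Set attributes on obj inside a with-block, then restore the originals.

    Restoration happens in a finally clause, so it runs even if the body raises.
    """
    saved = {name: getattr(obj, name) for name in overrides}
    try:
        for name, value in overrides.items():
            setattr(obj, name, value)
        yield obj
    finally:
        for name, value in saved.items():
            setattr(obj, name, value)


# Stand-ins for the trainer and its state (not real Lightning objects)
state = SimpleNamespace(stage="train")
trainer = SimpleNamespace(training=True, state=state)

try:
    with temporarily_set(trainer, training=False), temporarily_set(state, stage="validate"):
        print(trainer.training, trainer.state.stage)  # False validate
        raise RuntimeError("evaluation failed")  # simulate a crash mid-validation
except RuntimeError:
    pass

print(trainer.training, trainer.state.stage)  # True train
```

Inside the callback, the body would then become something like `with temporarily_set(trainer, training=False), temporarily_set(trainer.state, stage=RunningStage.VALIDATING): trainer._run_evaluate()`.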

To make the progress bar work properly for validation, you can extend an existing progress bar and override the following property:

class MyProgressBar(ProgressBar):
    @property
    def total_val_batches(self) -> int:
        """The total number of validation batches during validation, which may change from epoch to epoch.

        Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the validation
        dataloader is of infinite size.
        """
        total_val_batches = 0
        if self.trainer.enable_validation:
            is_val_epoch = (
                self.trainer.check_val_every_n_epoch == float("inf")
                or (self.trainer.current_epoch + 1)
                % self.trainer.check_val_every_n_epoch
                == 0
            )
            total_val_batches = sum(self.trainer.num_val_batches) if is_val_epoch else 0

        return total_val_batches
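The explicit `== float("inf")` test in that property matters because of how Python's modulo behaves with infinity: for a positive finite `x`, `x % float("inf")` returns `x` itself, so the modulo check alone would never report a validation epoch once `check_val_every_n_epoch` is infinite. A plain-Python illustration, where `is_val_epoch` is a standalone copy of the expression and not a Lightning function:

```python
def is_val_epoch(current_epoch, check_val_every_n_epoch):
    # Same expression as in MyProgressBar.total_val_batches above
    return (
        check_val_every_n_epoch == float("inf")
        or (current_epoch + 1) % check_val_every_n_epoch == 0
    )


inf = float("inf")
print((0 + 1) % inf)         # 1.0 -> the modulo alone is never 0 for epochs >= 0
print(is_val_epoch(0, inf))  # True, thanks to the explicit inf check
print(is_val_epoch(1, 2))    # True  (epoch index 1 is the 2nd epoch)
print(is_val_epoch(2, 2))    # False
```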

@tshu-w
Contributor

tshu-w commented Mar 22, 2022

> This commit mentions every_n_train_steps, but that results in

every_n_train_steps should be passed to ModelCheckpoint, not to Trainer.

@Helw150

Helw150 commented Apr 1, 2022

Just to add to the workaround development going on here, I've built on this to get model checkpointing, validation, and early stopping all based on the global step count. It also works on 1.5.4, but I wanted to remove the need to fork the Lightning source:

Slightly Modified Callback for Global Step Validation:

import pytorch_lightning as pl
from pytorch_lightning.trainer.states import RunningStage

class ValEveryNSteps(pl.Callback):
    def __init__(self, every_n_steps):
        self.last_run = None
        self.every_n_steps = every_n_steps

    def on_batch_end(self, trainer, pl_module):
        # Prevent running validation several times for one step under gradient accumulation
        if trainer.global_step == self.last_run:
            return
        else:
            self.last_run = None
        if trainer.global_step % self.every_n_steps == 0 and trainer.global_step != 0:
            trainer.training = False
            stage = trainer.state.stage
            trainer.state.stage = RunningStage.VALIDATING
            trainer._run_evaluate()
            trainer.state.stage = stage
            trainer.training = True
            trainer.logger_connector._epoch_end_reached = False
            self.last_run = trainer.global_step

Trainer args that base checkpointing and early stopping on these step-based validation runs:

from pytorch_lightning import Trainer

max_steps = 100_000  # example value: your maximum number of training steps
n_steps = 1_000  # example value: steps between val/checkpoint/early-stopping checks
trainer = Trainer(
    # ... your other Trainer arguments
    max_steps=max_steps,
    max_epochs=1000,
    check_val_every_n_epoch=1000,  # large, so epoch-based validation should rarely if ever trigger
    callbacks=[
        ValEveryNSteps(n_steps),
        pl.callbacks.ModelCheckpoint(
            monitor="val_loss", save_on_train_epoch_end=False
        ),
        pl.callbacks.EarlyStopping(
            monitor="val_loss",
            min_delta=0.00,
            patience=5,
            verbose=True,
            mode="min",
            check_on_train_epoch_end=False,
        ),
    ],
)
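To see why the `last_run` guard in the ValEveryNSteps callback above is needed: with gradient accumulation the batch-end hook fires once per micro-batch, but `global_step` only advances once per optimizer step, so the same step value is observed several times in a row. The plain-Python simulation below assumes `accumulate_grad_batches=4` and that `global_step` already counts completed optimizer steps when the hook runs (hook timing may differ between Lightning versions). `ValGuard` and `simulate` are hypothetical stand-ins that copy the callback's decision logic and record steps instead of running validation.

```python
class ValGuard:
    """Copies the decision logic of ValEveryNSteps.on_batch_end above."""

    def __init__(self, every_n_steps):
        self.every_n_steps = every_n_steps
        self.last_run = None
        self.runs = []  # global steps at which validation would have run

    def on_batch_end(self, global_step):
        # Same guard as the callback: skip repeated calls for one optimizer step
        if global_step == self.last_run:
            return
        else:
            self.last_run = None
        if global_step % self.every_n_steps == 0 and global_step != 0:
            self.runs.append(global_step)  # stands in for trainer._run_evaluate()
            self.last_run = global_step


def simulate(num_batches, accumulate, every_n_steps, use_guard=True):
    """Return the steps at which validation would run over num_batches micro-batches."""
    cb = ValGuard(every_n_steps)
    for batch_idx in range(num_batches):
        # Assumption: global_step counts completed optimizer steps at hook time
        global_step = (batch_idx + 1) // accumulate
        if not use_guard:
            cb.last_run = None  # forget the previous run, i.e. no deduplication
        cb.on_batch_end(global_step)
    return cb.runs


print(simulate(16, accumulate=4, every_n_steps=2))                   # [2, 4]
print(simulate(16, accumulate=4, every_n_steps=2, use_guard=False))  # [2, 2, 2, 2, 4]
```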

@yangyi02

yangyi02 commented Apr 4, 2022

pytorch-lightning is at 1.6.0 now. Does it support validating every n steps yet? I really want this for training and validating on a large-scale dataset.

I also want to validate for m steps rather than the whole validation dataset if possible.
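On validating for only m steps: Lightning's `Trainer` has a `limit_val_batches` argument (an int caps the number of validation batches, a float takes a fraction) that covers this inside the built-in loop. If you run evaluation by hand, the same idea is just truncating the iterator. A minimal plain-Python sketch with `itertools.islice`, where `evaluate_first_m` and the list-of-losses "loader" are hypothetical stand-ins:

```python
from itertools import islice


def evaluate_first_m(val_batches, m, step_fn):
    """Average step_fn over at most the first m batches of val_batches."""
    total, count = 0.0, 0
    for batch in islice(val_batches, m):  # stops after m items, even for endless iterators
        total += step_fn(batch)
        count += 1
    return total / count if count else float("nan")


# Stand-in "loader": each batch is a list of per-sample losses
loader = [[1.0, 3.0], [2.0, 2.0], [10.0, 10.0]]
print(evaluate_first_m(loader, m=2, step_fn=lambda b: sum(b) / len(b)))  # 2.0
```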

@carmocca
Contributor

Support was added in #11993.
