
[RFC] Refactor validation logic API to naturally support step-based training duration #12000

Closed
nikvaessen opened this issue Feb 19, 2022 · 9 comments
Labels: design (Includes a design discussion), refactor, trainer: argument, won't fix (This will not be worked on)

Comments

@nikvaessen (Contributor) commented Feb 19, 2022

Proposed refactor

Currently, users can control when validation happens with the following arguments in the Trainer object:

Trainer(
  val_check_interval: Union[int, float] = float(1.0),
  check_val_every_n_epoch: int = int(1)
)

I propose to refactor the API to be similar to the learning rate schedule config:

Trainer(
  validation_interval: "step" | "epoch" = "epoch",
  validation_frequency: Union[int, float] = int(1)
)

This has 4 modes of operation:

  1. validation_interval=step, validation_frequency=int(v), requires max_steps=int(N): We train for a maximum of N steps and validate every v steps, with v << N and v > 0, N > 0. For example, v=100, N=1000 would validate at step 100, 200, ..., 1000.

  2. validation_interval=step, validation_frequency=float(v), requires max_steps=int(N): We train for a maximum of N steps and validate every N*v steps, with 0 < v < 1 and N > 0. For example, v=0.1, N=1000 would validate at step 100, 200, ..., 1000.

  3. validation_interval=epoch, validation_frequency=int(v), requires max_epochs=int(N): We train for a maximum of N epochs and validate every v-th epoch, with v << N, and v > 0, N > 0. For example, v=2, N=100 would validate at the end of epoch 2, 4, ..., 100.

  4. validation_interval=epoch, validation_frequency=float(v), requires max_epochs=int(N): We train for a maximum of N epochs and validate multiple times during each epoch, according to v*len(dataloader), with 0 < v < 1, and N > 0. For example, v=0.5, N=100 would validate twice each epoch, once in the middle, and once at the end.

We would default to validation_interval=epoch and validation_frequency=1 to mimic the current default behavior.
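
For concreteness, here is a sketch of what the four modes above could look like at the call site. Note that validation_interval and validation_frequency are the arguments proposed in this RFC, not an existing Lightning API:

from pytorch_lightning import Trainer

# Mode 1: validate every 100 steps while training for at most 1000 steps.
trainer = Trainer(max_steps=1000, validation_interval="step", validation_frequency=100)

# Mode 2: validate every 0.1 * 1000 = 100 steps.
trainer = Trainer(max_steps=1000, validation_interval="step", validation_frequency=0.1)

# Mode 3: validate at the end of every 2nd epoch while training for at most 100 epochs.
trainer = Trainer(max_epochs=100, validation_interval="epoch", validation_frequency=2)

# Mode 4: validate twice per epoch (mid-epoch and at the end of the epoch).
trainer = Trainer(max_epochs=100, validation_interval="epoch", validation_frequency=0.5)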

Motivation

The current API was designed with an "epoch"-based training mindset common in the earlier days of deep learning. With the growth of dataset sizes and the popularity of cyclic or staged learning rate schedules, it has become more common to express training length in terms of steps instead of epochs. However, choosing a correct validation schedule when using max_steps=N with PyTorch Lightning is not intuitive, and users must still know the number of steps in each epoch to avoid mistakes.
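
To illustrate the problem, here is a hypothetical sketch assuming roughly 2,500 training batches per epoch and the current semantics of val_check_interval, where an integer value is counted within each epoch:

from pytorch_lightning import Trainer

# Goal: train for 10,000 steps and validate every 1,000 global steps.
trainer = Trainer(
    max_steps=10_000,
    val_check_interval=1_000,  # counted within each epoch, not across epochs
)
# With ~2,500 batches per epoch, validation runs at batches 1,000 and 2,000 of
# every epoch, i.e. at global steps 1,000, 2,000, 3,500, 4,500, ... rather than
# every 1,000 steps, so the user must still reason about the epoch length.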

Pitch

If we change to this proposed API, all current use cases will still be supported, and it will become much easier to reason about when validation happens.

Additional context

A stop-gap solution to support the behavior of validation_interval=step, validation_frequency=int(v) with the current API is provided in #11993.

cc @justusschock @awaelchli @rohitgr7 @tchaton @Borda @kaushikb11

@yangyi02 commented Apr 4, 2022

Really want this functionality! Thanks for proposing that!

There is also a similar thread here: #2534

@carmocca (Contributor) commented Apr 6, 2022

Have you had a look at the proposal in #8135 (comment)? It has the advantage of not introducing new arguments.

@nikvaessen (Contributor, Author) commented Apr 6, 2022

Yes, PR #11993 implements the first option you mention in #8135 (comment). However, I believe that solution doesn't adhere to the design principles of "Simple internal code" and "Simple external API", which is why I also proposed this RFC.

@carmocca added the design label Apr 6, 2022

@carmocca (Contributor) commented Apr 6, 2022

Thanks for your critical thinking! Let's cc @PyTorchLightning/core-lightning for thoughts on

  a. relying on the value types to differentiate between options, as described in #8135 (comment)
  b. deprecating the existing API and adding new explicit flags, as proposed in the top post.

If (b) is chosen, I would suggest keeping val_check_interval to match what you wrote as validation_frequency, and instead finding a less ambiguous name than validation_interval: val_granularity? val_cadence?

@nikvaessen (Contributor, Author) commented Apr 6, 2022

If (b) is chosen, I think the naming should be consistent with the learning rate schedule configuration: https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#configure-optimizers

Currently, interval is used there to choose between epoch and step, and frequency is used to indicate the number of steps/epochs. I do prefer granularity over interval, as I feel frequency and interval have similar meanings.
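
For reference, the scheduler configuration mentioned above looks roughly like this (the specific optimizer and scheduler are only illustrative):

import torch

# Inside a LightningModule
def configure_optimizers(self):
    optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1000)
    return {
        "optimizer": optimizer,
        "lr_scheduler": {
            "scheduler": scheduler,
            "interval": "step",  # "epoch" or "step"
            "frequency": 1,      # act on the scheduler every 1 interval
        },
    }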

@rohitgr7 (Contributor) commented Apr 6, 2022

What will be the configuration for check_val_every_n_epoch=4 and val_check_interval=0.25, i.e. do validation 4 times at every 4th training epoch?

validation_interval=step, validation_frequency=float(v), requires max_steps=int(N): We train for a maximum of N steps and validate every N*v steps

For this, I think it can be configured easily with just validation_frequency=N*v, since the user already knows the value of N.

I think we can just extend val_check_interval (possibly rename this) to support int values which can operate over the number of training batches processed overall.

Something like:

  1. if val_check_interval is float, consider the number of training batches per epoch
  2. if val_check_interval is int, consider overall training batches already processed
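
In code, that distinction might look like the following sketch (the int-counts-overall-batches behavior is the proposed extension, not what val_check_interval does today):

from pytorch_lightning import Trainer

# float: validate 4 times per epoch (every 25% of the epoch's training batches)
trainer = Trainer(val_check_interval=0.25)

# int: validate every 2,000 training batches, counted across epoch boundaries
trainer = Trainer(val_check_interval=2_000)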

@carmocca (Contributor) commented Apr 7, 2022

What will be the configuration for check_val_every_n_epoch=4 and val_check_interval=0.25, i.e. do validation 4 times at every 4th training epoch?

Yes

I think we can just extend val_check_interval (possibly rename this) to support int values which can operate over the number of training batches processed overall.

That's what I propose in #8135 (comment) by setting check_val_every_n_epoch=None (the default is 1).
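
A rough sketch of that configuration, assuming check_val_every_n_epoch=None switches val_check_interval to counting training batches across epoch boundaries:

from pytorch_lightning import Trainer

trainer = Trainer(
    max_steps=10_000,
    val_check_interval=1_000,      # every 1,000 training batches...
    check_val_every_n_epoch=None,  # ...counted across epochs, not per epoch
)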

stale bot commented Jun 6, 2022

This issue has been automatically marked as stale because it hasn't had any recent activity. This issue will be closed in 7 days if no further activity occurs. Thank you for your contributions, Pytorch Lightning Team!

The stale bot added the won't fix label Jun 6, 2022
@carmocca (Contributor) commented Jun 6, 2022

Closed by #11993
