Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Learning Rate finder #1347

Merged
merged 37 commits into from
Apr 10, 2020
Merged

Learning Rate finder #1347

merged 37 commits into from
Apr 10, 2020

Conversation

SkafteNicki
Copy link
Member

@SkafteNicki SkafteNicki commented Apr 2, 2020

Before submitting

  • Was this discussed/approved via a Github issue? (no need for typos and docs improvements)
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure to update the docs?
  • Did you write any new necessary tests?
  • If you made a notable change (that affects users), did you update the CHANGELOG?

What does this PR do?

Fixes #624
This PR implements a new method to the trainer class lr_finder=Trainer.find_lr(model), that are similar to the feature found in fast.ai. It does a small fit of the model, where the lr is increased after each batch and the corresponding loss is logged. The output object lr_finder can then be used to investigate the connection between choice of lr and the loss of the model. It can be used to reduce the amount of guesswork of choosing a good lr and it can be used to choose good bounds for the CyclicLRScheduler.

The interface is simple from a user-standpoint:

model = MyModelClass(hparams)
trainer = pl.Trainer()
lr_finder = trainer.find_lr(model)
# Plot results
lr_finder.plot(suggest=True)
# Choose based on plot, or get a suggestion
model.hparams.lr = lr_finder.suggestion()
# Fit
trainer.fit(model)

Running above code for pl_examples/basic_model/cpu_templatlightning_module_template.py model produces the following plot (red point corresponds to the suggested lr to use)

lr_finder

The feature seemed to gain much traction when it was proposed, however lightning was at that time missing a step-wise scheduling feature. This was implemented in PR #941, and this feature was therefore possible to implement now using more or less standard lightning features (callbacks ect.)

This PR is currently missing a lot (documentation, tests ect.) but I wanted I bit of feedback if this is still a wanted feature in lightning or if instead should be a part of lightning-bolts (when that is up and running).

PR review

Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in Github issues there's a high chance it will not be merged.

Did you have fun?

Make sure you had fun coding 🙃

@mergify mergify bot requested a review from a team April 2, 2020 15:54
@Borda Borda added the feature Is an improvement or enhancement label Apr 2, 2020
Copy link
Member

@Borda Borda left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very excited about this feature! 🤖

pytorch_lightning/trainer/lr_finder.py Outdated Show resolved Hide resolved
pytorch_lightning/trainer/lr_finder.py Outdated Show resolved Hide resolved
pytorch_lightning/trainer/lr_finder.py Outdated Show resolved Hide resolved
pytorch_lightning/trainer/lr_finder.py Outdated Show resolved Hide resolved
self.num_iter = num_iter
super(_LinearLR, self).__init__(optimizer, last_epoch)

def get_lr(self):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

having it as property lr?

pytorch_lightning/trainer/lr_finder.py Outdated Show resolved Hide resolved
pytorch_lightning/trainer/lr_finder.py Outdated Show resolved Hide resolved
pytorch_lightning/trainer/lr_finder.py Outdated Show resolved Hide resolved
Comment on lines 78 to 105
# Max step set to number of iterations
max_steps = self.max_steps
self.max_steps = num_iters

# Disable standard progress bar for fit
show_progress_bar = self.show_progress_bar
self.show_progress_bar = False

# Accumulation of gradients
accumulate_grad_batches = self.accumulate_grad_batches
self.accumulate_grad_batches = num_accumulation_steps

# Configure optimizer and scheduler
optimizer, _, _ = self.init_optimizers(model.configure_optimizers())
assert len(optimizer) == 1, 'cannot find lr for more than 1 optimizer'
configure_optimizers = model.configure_optimizers
model.configure_optimizers = lr_finder._get_new_optimizer(optimizer[0])

# Fit, lr & loss logged in callback
self.fit(model)

# Promt if we stopped early
if self.global_step != num_iters:
print('LR finder stopped early due to diverging loss.')

# Transfer results from callback to lr finder object
lr_finder.results.update({'lr': self.callbacks[0].lrs,
'loss': self.callbacks[0].losses})
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

regarding the logger suppression, it would be nice to have this as a function/method with a wrapper...

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you explain a bit more?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

there is a single block of code which you want to execute "silently"
co make it as a separate function and write a wrapper which disables the logger and later restore

@Borda Borda changed the title Lr finder Learning Rate finder Apr 2, 2020
@Borda Borda added this to the 0.7.2 milestone Apr 2, 2020
Copy link
Contributor

@williamFalcon williamFalcon left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is awesome! I'd prefer to put this into an argument in the trainer... not have an additional method the user has to call.

Trainer(auto_find_lr=True)

def auto_find_lr(self, model):

  lr_finder = self.find_lr(model)
  # Plot results

  # Choose based on plot, or get a suggestion
  print(f'suggested lr {lr_finder.suggestion()}')

  # automatically update the optimizers

@justusschock @ethanwharris thoughts?

Some caveats:

  1. This won't work with dp/ddp. We should modify for that case.
  2. what about multiple optimizers?

I recommend not merging until we hash out the API.
Ok to merge a v1 of this that doesn't support multiple optimizers or dp/ddp but we need to work those out eventually. This should also be clearly written in the docs

@SkafteNicki
Copy link
Member Author

  1. @williamFalcon can you explain why this wont work with dp or ddp? I though that if the method internally calls Trainer.fit() to do the actual work, then it should work out of the box. I have no experience with dp or ddp so I would need help to get this feature to support this.

  2. I don't think that other frameworks support this feature for more than 1 optimizer. However, I guess that it can be done using a simple grid search. The search would take num_optimizers * num_iters steps, so it would in most cases take <1000 steps to do this.

@williamFalcon
Copy link
Contributor

@SkafteNicki if you call .fit internally then it should be fine!

Overall though, take a look at my comments about collapsing this into a flag

@SkafteNicki
Copy link
Member Author

The idea of having this as a separate method is just taken directly from fastai. I am fine by collapsing this into a flag, however I think it removes the possibility for the user to interact with the results produced by learning rate finder before fitting the model.

@justusschock
Copy link
Member

justusschock commented Apr 3, 2020

@williamFalcon I wouldn't migrate it to a trainer arg.

What I'd really like is something like:

# init your Trainer:
trainer = Trainer (...)
with trainer.find_best_lr:
    trainer.fit

This is not much boilerplate, but I think we should not make to much implicit choices.

And if you would pass the optimiser there, you could also do:

with trainer.find_best_lr(optim1):
    with trainer.find_best_lr(optim2):
        trainer.fit()

Which would be equal to:

with trainer.find_best_lr(optim1, optim2):
    trainer.fit()

Not sure, how realistic this is, but that's the API, I'd like the most...

@Borda
Copy link
Member

Borda commented Apr 3, 2020

I recommend not merging until we hash out the API.
Ok to merge a v1 of this that doesn't support multiple optimizers or dp/ddp but we need to work those out eventually. This should also be clearly written in the docs

I would then rather target the next release and have it in v0.8.0 with metrics

@mergify
Copy link
Contributor

mergify bot commented Apr 3, 2020

This pull request is now in conflict... :(

@williamFalcon williamFalcon modified the milestones: 0.7.2, 0.7.3 Apr 3, 2020
Copy link
Contributor

@jeremyjordan jeremyjordan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

great contribution! excited about this feature, just have a few comments to address.

docs/source/lr_finder.rst Outdated Show resolved Hide resolved
docs/source/lr_finder.rst Outdated Show resolved Hide resolved
pytorch_lightning/trainer/lr_finder.py Outdated Show resolved Hide resolved
pytorch_lightning/trainer/lr_finder.py Outdated Show resolved Hide resolved
pytorch_lightning/trainer/lr_finder.py Show resolved Hide resolved
@jeremyjordan
Copy link
Contributor

yeah i agree with @justusschock i wouldn't use a Trainer arg here. i personally like the recommended usage as it stands.

(copied from @SkafteNicki's post above )

model = MyModelClass(hparams)
trainer = pl.Trainer()
lr_finder = trainer.find_lr(model)

# Plot results
lr_finder.plot(suggest=True)

# Choose based on plot, or get a suggestion
model.hparams.lr = lr_finder.suggestion()

# Fit
trainer.fit(model)

users can name their learning rate hparam however they please (lr, learning_rate, etc.) so it would be difficult to automatically set this value. sure, we can update the actual lr for the optimizer but that wouldn't be reflected in the hparams that are logged.

Copy link
Contributor

@awaelchli awaelchli left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is an amazing feature!
I have some suggestions for the docs :)

docs/source/lr_finder.rst Outdated Show resolved Hide resolved
docs/source/lr_finder.rst Outdated Show resolved Hide resolved
pytorch_lightning/trainer/lr_finder.py Outdated Show resolved Hide resolved
pytorch_lightning/trainer/lr_finder.py Outdated Show resolved Hide resolved
pytorch_lightning/trainer/lr_finder.py Outdated Show resolved Hide resolved
pytorch_lightning/trainer/lr_finder.py Outdated Show resolved Hide resolved
pytorch_lightning/trainer/lr_finder.py Outdated Show resolved Hide resolved
@lkhphuc
Copy link

lkhphuc commented Apr 4, 2020

Great PR, I'm also looking for this feature. Thanks everyone.

yeah i agree with @justusschock i wouldn't use a Trainer arg here. i personally like the recommended usage as it stands.

I think the current usage is great to manipulate lr finder programmatically, but it would also be nicer if we have a flag for interactive training. Something like this:

$ python train.py --find_lr=True
[INFO] ....
[INFO] (Plot learning rate finder, inplace for notebook, pop up for terminal)
The suggested learning rate is 3e-4, press enter to accept or input a different value:  >>ENTER<<
[INFO] Learning rate is 3e-4.
[INFO] ....

or

$ python train.py --find_lr=True
[INFO] ....
[INFO] (Plot learning rate finder, inplace for notebook, pop up for terminal)
The suggested learning rate is 3e-4, press enter to accept or input a different value:  3e-5 >>ENTER<<
[INFO] Learning rate is 3e-5.
[INFO] ....

This flag will take precedence over all other lr related flag and the default will of course be False.

@williamFalcon
Copy link
Contributor

williamFalcon commented Apr 4, 2020

good ideas. i think we don’t want to set a trend for breaking out of the current API patterns. the trainer flags are there so that users don’t have to think about all the tiny nuances of doing something.

the approach auggested by @justusschock means the user now has to learn more API and has to remember to do a bunch of things... ie: they need to go read docs, also start pulling out optimizers, this completely breaks the lightning organization and abstraction and doesn’t go with the principle of not having to make users think about things they don’t need to think about.

the user just wants the best LR, they shouldn’t have to remember to add a with, or pull out optimizers, etc... they should set a flag and get the LR.

with the trainer flag the user doesn’t have to think about it. in fact, the library can just automatically set the best LR using what it finds... ie: it just works. we can still support the graph approach in this case by showing the plot in the logger. then the user can decide to fix the LR once they feel confident.

the approach i suggest with a flag works as follows:

  1. set the flag.
  2. the LR is found automatically.
  3. we print a nice message with the LR.
  4. the LR is set automatically and training continues
  5. the curve is also logged to the logger (if there is one).
  6. if the user wants to inspect the log and pick a LR manually, they can do that.
  7. at this point the user would just likely set the LR manually going forward from what printed or what is shown in the plot.

this approach has the advantages that:

  1. the user does not have to remember overhead of how to do this... this is a CORE value of lightning. if we lose sight of this we end up with another framework where you build up a lot of cognitive overhead to remember how to do things. this is an engineering decision which should be automated.
  2. we still get the plot which the user can interact with.
  3. the user doesn’t have to comment or delete code which they would have to do with the other approaches suggested... this will clutter code really quickly.

So, i’m going to strongly suggest we use a trainer flag instead.

@justusschock
Copy link
Member

Okay, gut can we pass an object where the best of should be stored (e.g. hparams.lr) ? If you have multiple runs you probably don't want to do the LR search every time

@williamFalcon
Copy link
Contributor

williamFalcon commented Apr 4, 2020

I guess i assume the flow would be:

  1. enable flag
  2. it prints the best lr
  3. follow on runs, disable flag and manually fix the LR going forward

what would that object do?

@jeremyjordan
Copy link
Contributor

the library can just automatically set the best LR using what it finds... ie: it just works

let me elaborate on why i think this is challenging

suppose user A has a model like:

class LitModel(pl.LightningModule):

    def __init__(self, hparams):
        super().__init__()
        self.hparams = hparams
        self.l1 = torch.nn.Linear(28 * 28, 10)

    def forward(self, x):
        return torch.relu(self.l1(x.view(x.size(0), -1)))

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.forward(x)
        return {'loss': F.cross_entropy(y_hat, y)}

    def train_dataloader(self):
        return DataLoader(MNIST(os.getcwd(), train=True, download=True,
                          transform=transforms.ToTensor()), batch_size=32)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)

and user B has a model like:

class LitModel(pl.LightningModule):

    def __init__(self, hparams):
        super().__init__()
        self.hparams = hparams
        self.l1 = torch.nn.Linear(28 * 28, 10)

    def forward(self, x):
        return torch.relu(self.l1(x.view(x.size(0), -1)))

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.forward(x)
        return {'loss': F.cross_entropy(y_hat, y)}

    def train_dataloader(self):
        return DataLoader(MNIST(os.getcwd(), train=True, download=True,
                          transform=transforms.ToTensor()), batch_size=32)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)

how do we automatically set the learning rate?

sure, we can update the pytorch optimizer'slr but i don't see how we could reliably update the model.hparams object to include the new value automatically. this is important because we're logging the hparams so that users can reproduce their results from previous experiments.

i would agree with the general flow:

  • enable learning rate feature
  • report the best learning rate
  • user sets this hyperparameter and does a full training run

i'm still not sure a flag is the best design for this. by keeping this functionality as a method, we're still only asking the user to call essentially one line of code trainer.find_lr(model). this is as simple as calling trainer.fit(model) for actual training. thus, the cognitive burden on the user is equivalent (remembering a flag name vs method name) and typing completion in editors like VS code will help in both cases.

minimal example

# find learning rate
model = MyModelClass(hparams)
trainer = pl.Trainer()
trainer.find_lr(model)

# do training run
hparams = {**hparms, 'lr': lr_finder.suggestion()}
model = MyModelClass(hparams)
trainer.fit(model)

or the user can reach into the existing model and update an hparam since we call configure_optimizers at the beginning of a trainer.fit()

# find learning rate
model = MyModelClass(hparams)
trainer = pl.Trainer()
trainer.find_lr(model)

# do training run
model.hparams.lr = lr_finder.suggestion()
trainer.fit(model)

power user

# find learning rate
model = MyModelClass(hparams)
trainer = pl.Trainer()
lr_finder = trainer.find_lr(model)       # user saves the returned results
lr_finder.plot(suggest=True)             # and can further inspect if they desire

# do training run
model.hparams.lr = lr_finder.suggestion()
trainer.fit(model)

@Borda Borda added the discussion In a discussion stage label Apr 4, 2020
@mergify mergify bot requested a review from a team April 9, 2020 16:44
Copy link
Contributor

@jeremyjordan jeremyjordan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

great work on this! ⚡

@mergify mergify bot requested a review from a team April 10, 2020 02:22
@mergify mergify bot requested a review from a team April 10, 2020 15:59
@williamFalcon
Copy link
Contributor

@SkafteNicki this is an awesome feature!

@mergify
Copy link
Contributor

mergify bot commented Apr 10, 2020

This pull request is now in conflict... :(


lr_max: lr to stop seach

num_training: number of steps to take between lr_min and lr_max
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

rather num_train_steps

self.num_iter = num_iter
super(_ExponentialLR, self).__init__(optimizer, last_epoch)

def get_lr(self):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it shall be described what is the doff between this get_lr and just lr bellow because intuitively (by the name) they shall return the same

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

get_lr() is the method called inside lr_scheduler.step() and is not meant to be called elsewhere. Since pytorch 1.4 the property self._last_lr was introduced to extract the last computed lr. However, since pytorch-ligning need to be backwards compatible, I created the self.lr property that archives the same. They therfore have slightly different purpose.

@williamFalcon williamFalcon merged commit 3f09b32 into Lightning-AI:master Apr 10, 2020
@SkafteNicki SkafteNicki deleted the lr_finder branch April 21, 2020 13:51
tullie pushed a commit to tullie/pytorch-lightning that referenced this pull request Jun 7, 2020
* initial structure

* rebase

* incorporate suggestions

* update CHANGELOG.md

* initial docs

* fixes based on reviews

* added trainer arg

* update docs

* added saving/restore of model state

* initial tests

* fix styling

* added more tests

* fix docs, backward compatility and progressbar

* fix styling

* docs update

* updates based on review

* changed saving to standard functions

* consistent naming

* fix formatting

* improve docs, added support for nested fields, improve codecov

* update CHANGELOG.md

* Update lr_finder.rst

* Update pytorch_lightning/trainer/trainer.py

* Update trainer.py

* Update CHANGELOG.md

* Update path

* restoring

* test

* attribs

* docs

* doc typo

Co-authored-by: Nicki Skafte <[email protected]>
Co-authored-by: William Falcon <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: J. Borovec <[email protected]>
@Borda Borda modified the milestones: 0.7.4, v0.7.x Apr 18, 2021
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
discussion In a discussion stage feature Is an improvement or enhancement
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Cyclic learning rate finder as a part of Trainer
8 participants