
Feature: auto scale batch size #1638

Merged

Conversation

SkafteNicki
Member

@SkafteNicki SkafteNicki commented Apr 27, 2020

Before submitting

  • Was this discussed/approved via a Github issue? (no need for typos and docs improvements)
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure to update the docs?
  • Did you write any new necessary tests?
  • If you made a notable change (that affects users), did you update the CHANGELOG?

What does this PR do?

Fixes #1615 and #1444.
This implements an algorithm that automatically finds the largest possible batch size that fits into memory (no OOM). Currently two modes are supported: power and binsearch. power iteratively multiplies the batch size by 2 until an OOM error is encountered, then stops. binsearch additionally refines the batch size from that point using a binary search strategy.
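
To make the two modes concrete, below is a minimal, simplified sketch of the search loop. It is not the implementation in this PR; try_fit is a hypothetical callback that runs a few training steps at a given batch size.

import torch


def find_batch_size(try_fit, init_val=2, max_trials=25, mode='power'):
    """Sketch: double the batch size until OOM, then optionally binary-search between the bounds."""
    low, high = init_val, None           # assumes init_val itself fits into memory
    new_size = init_val
    for _ in range(max_trials):
        try:
            try_fit(new_size)            # run a few training steps at this batch size
        except RuntimeError as err:
            if 'out of memory' not in str(err):
                raise                    # only OOM failures are treated as a search signal
            torch.cuda.empty_cache()     # release the memory held by the failed attempt
            if mode == 'power':
                return low               # power mode stops at the last size that worked
            high = new_size              # binsearch: the failed size becomes the upper bound
        else:
            low = new_size               # this size worked, so raise the lower bound
            if high is not None and high - low <= 1:
                break
        # next candidate: keep doubling while nothing has failed yet, bisect afterwards
        new_size = low * 2 if high is None else (low + high) // 2
    return low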

In power mode (the default), the output in the terminal currently looks something like this (running the LightningTemplateModel):
[Screenshot from 2020-04-27 16-58-02]

The interface for this feature is currently very similar to the learning rate finder introduced some time ago. In the basic case the user can set the trainer flag auto_scale_batch_size=True and the batch size finder will run when .fit() is called. Similar to the learning rate finder, this assumes the model has an hparams field model.hparams.batch_size that can be overridden with whatever batch size is found. If the user instead wants to write to another field, this can be done with auto_scale_batch_size='my_field' (corresponding to model.hparams.my_field).
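
As a rough usage sketch of the flags described above (illustrative only; the model is assumed to be a LightningModule, defined elsewhere, with a model.hparams.batch_size field):

from pytorch_lightning import Trainer

# default target field: model.hparams.batch_size is overridden with the found value
trainer = Trainer(auto_scale_batch_size=True)
trainer.fit(model)

# write the result to a different hparams field instead, e.g. model.hparams.my_field
trainer = Trainer(auto_scale_batch_size='my_field')
trainer.fit(model)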

Power users can, after initializing the trainer, invoke the method scale_batch_size directly and thereby control the search through the method's parameters.
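
Something along these lines, where the keyword arguments shown (mode, init_val, max_trials) are indicative of the knobs the method exposes rather than a guaranteed signature, and model is again a LightningModule defined elsewhere:

from pytorch_lightning import Trainer

trainer = Trainer()
new_batch_size = trainer.scale_batch_size(model, mode='binsearch', init_val=2, max_trials=25)
model.hparams.batch_size = new_batch_size
trainer.fit(model)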

WIP right now as tests and better documentation are still missing. I also need to figure out exactly where this should live in the codebase: it is currently in TrainerTrainingTricksMixin but should maybe be its own mixin.

PR review

Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in Github issues there's a high chance it will not be merged.

Did you have fun?

Make sure you had fun coding 🙃

@pep8speaks

pep8speaks commented Apr 27, 2020

Hello @SkafteNicki! Thanks for updating this PR.

There are currently no PEP 8 issues detected in this Pull Request. Cheers! 🍻

Comment last updated at 2020-05-09 08:43:53 UTC

@mergify mergify bot requested a review from a team April 27, 2020 15:22
@mergify mergify bot requested a review from a team April 27, 2020 15:30
@mergify mergify bot requested a review from a team April 27, 2020 15:31
@mergify
Contributor

mergify bot commented May 1, 2020

This pull request is now in conflict... :(

@SkafteNicki SkafteNicki changed the title [WIP] Feature: auto scale batch size Feature: auto scale batch size May 4, 2020
@Borda Borda added the feature (Is an improvement or enhancement) label May 4, 2020
@Borda Borda added this to the 0.7.6 milestone May 4, 2020
@williamFalcon
Contributor

williamFalcon commented May 5, 2020

/rebase

@williamFalcon
Contributor

@SkafteNicki i really want to try this haha... can we merge?

@Borda Borda force-pushed the feature/auto_batch_size branch from 0a639b5 to 2d3988a Compare May 5, 2020 19:08
@SkafteNicki
Member Author

I think it works as it should (at least for me). I could maybe add more tests if you want. You are also welcome to check out the branch before merging.

@williamFalcon
Contributor

all good. let’s get all the tests to pass

Member

@Borda Borda left a comment


This is a great addition ❤️

Review comments (outdated, now resolved) on: docs/source/training_tricks.rst, pytorch_lightning/trainer/__init__.py, pytorch_lightning/trainer/training_tricks.py, tests/trainer/test_trainer_tricks.py
@mergify mergify bot requested a review from a team May 5, 2020 19:37
Contributor

@awaelchli awaelchli left a comment


Went on a typo hunt :) hope you don't mind my pedantic behaviour :)
The batch finder function is quite complex and could be abstracted a bit further. How hard would it be to add a third mode if we needed to do so in the future?

Review comments (outdated, now resolved) on: docs/source/training_tricks.rst, pytorch_lightning/trainer/__init__.py, pytorch_lightning/trainer/training_tricks.py
@mergify mergify bot requested a review from a team May 5, 2020 20:43
@mergify mergify bot requested a review from a team May 5, 2020 21:45
@SkafteNicki
Member Author

@Borda & @williamFalcon I messed up when I tried to pull the latest changes that you made, and now the latest commit edits a lot more files than intended. Is there a way to fix this? Basically revert 1 commit. I am no git expert.

@Borda
Member

Borda commented May 6, 2020

@Borda & @williamFalcon I messed up when I tried to pull the latest changes that you made, and now the latest commit edits a lot more files than intended. Is there a way to fix this? Basically revert 1 commit. I am no git expert.

Sure, just drop the last commit :] with git rebase -i HEAD~2, or I can do it for you...

EDIT: it does not allow me to fetch your branch, it crashes...

@Borda Borda force-pushed the feature/auto_batch_size branch from df02a2c to 810ecc8 Compare May 6, 2020 13:47
@Borda
Member

Borda commented May 6, 2020

@SkafteNicki I have tried to get it back and seems to be fine now, but pls check it...

@SkafteNicki
Member Author

On a side note, I think that in the future both this feature and the learning rate finder should be redone (I made them both, so this is my own fault). I realized that both follow a pattern:

  1. dump current state of model and trainer
  2. alter some trainer args/variables to suit the feature
  3. do the feature by calling .fit() internally
  4. save results from the feature
  5. restore the initial settings

Instead of all the hassle of saving, altering and restoring, it is probably a better idea to just initialize a new instance of pl.Trainer inside the feature and copy over the important settings (like device) to the new instance. Then only the initial state of the model (which can easily be done if issue #1619 is solved) needs to be saved/restored.
This pattern is not only present in these two features; future features will also follow it. For example, I have looked a bit at incorporating a cross-validation feature into pl (issue #839), and it follows exactly the same pattern.
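
Schematically, the difference between the current and the proposed pattern could look like this (a sketch only; the attribute choices and the trial_steps argument are illustrative, not real internals):

from pytorch_lightning import Trainer


def run_feature_current(trainer, model, trial_steps):
    # current pattern: temporarily mutate the user's trainer and undo it afterwards
    backup = {'max_steps': trainer.max_steps, 'logger': trainer.logger}  # 1. dump current state
    trainer.max_steps, trainer.logger = trial_steps, None                # 2. alter trainer args for the feature
    trainer.fit(model)                                                   # 3. run the feature via .fit()
    results = trainer.callback_metrics                                   # 4. save the results
    trainer.max_steps, trainer.logger = backup['max_steps'], backup['logger']  # 5. restore the initial settings
    return results


def run_feature_proposed(trainer, model, trial_steps):
    # proposed pattern: spin up a throwaway Trainer (copying over important settings such as the device)
    # and leave the user's trainer untouched; only the model state still needs to be saved/restored
    tmp_trainer = Trainer(max_steps=trial_steps, logger=False)
    tmp_trainer.fit(model)
    return tmp_trainer.callback_metrics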

@Borda
Member

Borda commented May 8, 2020

There is a suggestion from @tullie (which I like) that we may consider splitting this hyperparameter tuning (batch size, learning rate) into a separate class/object Tuner, to lower the code complexity a bit and make it more transparent...
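
Purely hypothetically, such a Tuner object could be sketched like this (class and method names are made up for illustration; this is not an existing pytorch_lightning API):

class Tuner:
    """Hypothetical sketch of a separate object that owns the hyperparameter tuning logic."""

    def __init__(self, trainer):
        self.trainer = trainer

    def scale_batch_size(self, model, mode='power'):
        ...  # would wrap the batch size finder introduced in this PR

    def lr_find(self, model):
        ...  # would wrap the existing learning rate finder


# usage would then keep Trainer itself minimal, e.g.:
# tuner = Tuner(trainer)
# model.hparams.batch_size = tuner.scale_batch_size(model)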

@@ -474,7 +482,7 @@ def __init__(
         self.show_progress_bar = show_progress_bar

         self.progress_bar_refresh_rate = progress_bar_refresh_rate
-        self.progress_bar_callback = None
+        self.progress_bar_callback = progress_bar_callback
Member


the arg progress_bar_callback was not used anywhere - it had been forgotten; I hope this is the right place...

Comment on lines +298 to +309
count += 1
if count > max_trials:
    break
# Double in size
low = new_size
if high:
    if high - low <= 1:
        break
    midval = (high + low) // 2
    new_size = _adjust_batch_size(trainer, batch_arg_name, value=midval, desc='succeeded')
else:
    new_size = _adjust_batch_size(trainer, batch_arg_name, factor=2.0, desc='succeeded')
Member


I would rather move this to the else section, as here we do not expect any failure, right?

Member Author


This is meant to check whether we are still in the initial phase of doubling the batch size, or whether we have already failed once (i.e. high is defined) and are thus in the binary search phase.

Member


Sure, I mean

try:
    ...  # do something
except RuntimeError:
    ...  # do this if the something failed
else:
    ...  # do this only if the something passed
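
Applied to the batch scaling loop, that suggestion would look roughly like this (a sketch only; the helper names are illustrative stand-ins and the break conditions are omitted):

try:
    _run_trial_steps(trainer, model, new_size)   # a few fit steps at the candidate batch size
except RuntimeError as exception:
    if not _is_oom_error(exception):             # anything other than OOM is a real error
        raise
    _garbage_collect_cuda()                      # free the memory held by the failed attempt
    high = new_size                              # the failed size becomes the upper bound
    new_size = (high + low) // 2                 # shrink via bisection
else:
    low = new_size                               # the candidate succeeded: raise the lower bound
    new_size = low * 2 if high is None else (high + low) // 2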

@mergify mergify bot requested a review from a team May 8, 2020 11:14
@mergify mergify bot requested a review from a team May 8, 2020 11:17
@Borda Borda requested review from awaelchli and justusschock May 8, 2020 11:57
@Borda Borda added the ready (PRs ready to be merged) label May 8, 2020
@SkafteNicki
Member Author

There is a suggestion from @tullie (which I like) that we may consider splitting this hyperparameter tuning (batch size, learning rate) into a separate class/object Tuner, to lower the code complexity a bit and make it more transparent...

I also like this. I really think it is great that Lightning has kept its interface so minimalist for most users, e.g. most users only need to interact with LightningModule and Trainer. However, I also think the time has come to extend the interface for these more advanced features.

@williamFalcon
Contributor

We can explore this on @tullie's GH issue. In the meantime, let's get this merged haha.

@mergify
Contributor

mergify bot commented May 9, 2020

This pull request is now in conflict... :(

@williamFalcon williamFalcon merged commit 4970927 into Lightning-AI:master May 9, 2020
def is_cudnn_snafu(exception):
    return isinstance(exception, RuntimeError) \
        and len(exception.args) == 1 \
        and "cuDNN error: CUDNN_STATUS_NOT_SUPPORTED." in exception.args[0]
Contributor

@BlackHC BlackHC May 25, 2020


Thanks for this PR to implement toma in PyTorch Lightning!

If you copy code and ideas from my projects, could you please add a mention of them, too? I see that you're a fellow PhD student, so you are aware of the importance of credit assignment.

In particular, if you copy code verbatim and remove helpful comments... maybe add them back.

def is_cuda_out_of_memory(exception):
    return (
        isinstance(exception, RuntimeError) and len(exception.args) == 1 and "CUDA out of memory." in exception.args[0]
    )


def is_cudnn_snafu(exception):
    # For/because of https://github.com/pytorch/pytorch/issues/4107
    return (
        isinstance(exception, RuntimeError)
        and len(exception.args) == 1
        and "cuDNN error: CUDNN_STATUS_NOT_SUPPORTED." in exception.args[0]
    )


def gc_cuda():
    """Garbage collect Torch (CUDA) memory."""
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

is from https://github.com/BlackHC/toma/blob/master/toma/torch_cuda_memory.py.

Now, I think this PR contains lots of other code, and I think it's great, but maybe add a link or a mention.

Thank you,
Andreas

PS: Keeping my rather random method names is a bit of a give-away.

@williamFalcon
Contributor

williamFalcon commented May 25, 2020

@BlackHC i'm sorry, i had no idea this code was copied... @SkafteNicki generally we like to do our own implementations of features, and under no circumstance do we allow code copying.

I suggest a few things to rectify this:

  1. we use toma as is and add them as a dependency for this feature
    or
  2. we come up with our own actual implementation.

But seeing how the code was copied from the toma repo i would rather play nice and bring them in as an actual dependency.

@BlackHC my deepest apologies, i was not aware that this code came from your repo!

@PyTorchLightning/core-contributors thoughts?

@BlackHC
Contributor

BlackHC commented May 25, 2020

The rest of the code seems quite original / I haven't reviewed it in detail. I'm sure you have a good understanding of it and its quality because it follows a slightly different approach than toma. With the binary search and the potential of using a higher batch size than specified, it might be worth looking into ghost batchnorm in the future if this is used for training.

What would be great is:

  • add an inspired by Andreas Kirsch's https://github.com/BlackHC/toma comment in the source/feature docs,
  • add the # For/because of https://github.com/pytorch/pytorch/issues/4107 comment back to explain why it checks for that exception (it's a bit magical otherwise); and
  • add a 'based on https://github.com/BlackHC/toma/blob/master/toma/torch_cuda_memory.py' comment to utilities/memory.py?

I can grant you a license for those lines of code outside of the MIT license used in toma, so it should be fine. No need to rewrite things.

Please let me know what you think.

Thanks,
Andreas

@Borda
Member

Borda commented May 25, 2020

@BlackHC I am sorry for this, I was not aware of it...
I would add it as a dependency; there is no need to reinvent the wheel unless we can get a better wheel :]
@SkafteNicki or @BlackHC (to also be a mentioned contributor in the next release) mind sending a PR importing the mentioned util functions from your lib?

@BlackHC
Contributor

BlackHC commented May 25, 2020

Thank you very much!

From a dependency point of view, it's only the functions I mentioned (and is_out_of_cpu_memory from https://github.com/BlackHC/toma/blob/master/toma/cpu_memory.py), so I'm not sure it's worth including the full dependency at this point. Just mentioning the original source might be enough.

If toma adds lots of functionality, it might be worth having another look. It does not scale batch sizes up but only down at the moment, as it follows a slightly different paradigm.

I'm really impressed by how quickly you have replied and reacted. I think it's amazing.

Thanks,
Andreas

@Borda
Member

Borda commented May 25, 2020

@williamFalcon we may think about using some other functionalities but

which is a very limiting factor for us...
the good side is that its only dependencies are torch and psutil

@SkafteNicki
Member Author

I am so sorry about this. Let me try to explain. I originally tried to integrate toma into Lightning (because it is an awesome library) but could not figure out how to get some of the functionality to work with the Lightning interface, especially the binary search and hparams. I therefore ended up doing a custom implementation myself. Everything except the utility functions for determining when we are out of memory I wrote myself, but I am truly sorry that, in the heat of programming, I forgot to reference the original source code for these functions. I am sorry if I have offended you @BlackHC, it was never my intent; I would never do that to a fellow PhD student.
@williamFalcon and @Borda I will gladly rectify my mistake in a PR by updating all the code with correct references.

@Borda
Member

Borda commented May 25, 2020

@williamFalcon and @Borda I will gladly rectify my mistake in a PR by updating all the code with correct references.

We need to resolve the license case in dependencies; then I would suggest that @BlackHC make this simple PR so he is also on the contributor list (as it is generated from PR authors).

@williamFalcon
Contributor

williamFalcon commented May 25, 2020

@SkafteNicki no worries, i'm sure it wasn't malicious - we just want to play fair with the broader set of tools.

Let's follow @BlackHC suggestions here and make the correct PRs/adjustment to our codebase.

@BlackHC thank you for providing the rights to use! We will make sure to follow the considerations you requested.

Thanks!

@BlackHC
Contributor

BlackHC commented May 26, 2020

Thank you very much! I'll prepare a PR shortly (probably this evening as we have an internal NeurIPS deadline before that 😬).

@SkafteNicki thanks for explaining and no worries! I think it's great you implemented it in a way that targets PyTorch Lightning specifically, and I'm glad I was able to provide inspiration with toma, and the utility functions were useful. This is a big part of why open-source is great. I hope you'll keep contributing to PyTorch Lightning!

Thanks,
Andreas
