Feature: auto scale batch size #1638
Conversation
Hello @SkafteNicki! Thanks for updating this PR. There are currently no PEP 8 issues detected in this Pull Request. Cheers! 🍻 Comment last updated at 2020-05-09 08:43:53 UTC
This pull request is now in conflict... :(
/rebase
@SkafteNicki i really want to try this haha... can we merge?
Force-pushed from 0a639b5 to 2d3988a
I think it works as it should (at least for me). I could maybe add more tests if you want. You are also welcome to check out the branch before merging.
all good. let’s get all the tests to pass
This is a great addition ❤️
went on a typo hunt :) hope you don't mind my pedantic behaviour :)
The batch finder function is quite complex and could be abstracted a bit further. How hard would it be to add a third mode if we needed to do so in the future?
@Borda & @williamFalcon I messed up when I tried to pull the latest changes that you made, and now the latest commit edits a lot more files than intended. Is there a way to fix this? Basically revert 1 commit. I am no git expert.
sure, just drop the last commit :]
EDIT: it does not allow me to fetch your branch, it crashes...
Force-pushed from df02a2c to 810ecc8
@SkafteNicki I have tried to get it back and it seems to be fine now, but pls check it...
On a side note, in the future I think both this feature and the learning rate finder should be redone (I made them both, so this is my own fault). I realized that both follow a pattern:
Instead of all the hassle of saving, altering and restoring, it is probably a better idea to just initialize a new instance of ...
There is a suggestion from @tullie (which I like) that we may consider splitting this hyperparameter tuning (batch size, learning rate) into a separate class/object, Tuner, to lower the code complexity a bit and make it more transparent...
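A rough sketch of what such a split could look like (the Tuner class and its method signatures here are hypothetical illustrations, not code from this PR):

class Tuner:
    """Hypothetical object that would own the hyperparameter search utilities."""

    def __init__(self, trainer):
        self.trainer = trainer

    def scale_batch_size(self, model, mode='power', max_trials=25):
        ...  # the batch size finder logic would move here

    def lr_find(self, model, min_lr=1e-8, max_lr=1.0):
        ...  # the learning rate finder logic would move here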
@@ -474,7 +482,7 @@ def __init__(
        self.show_progress_bar = show_progress_bar

        self.progress_bar_refresh_rate = progress_bar_refresh_rate
-       self.progress_bar_callback = None
+       self.progress_bar_callback = progress_bar_callback
the arg progress_bar_callback was not used anywhere - forgotten, hope it is the right place...
count += 1
if count > max_trials:
    break
# Double in size
low = new_size
if high:
    if high - low <= 1:
        break
    midval = (high + low) // 2
    new_size = _adjust_batch_size(trainer, batch_arg_name, value=midval, desc='succeeded')
else:
    new_size = _adjust_batch_size(trainer, batch_arg_name, factor=2.0, desc='succeeded')
I would rather move this to the else section, as here we do not expect any failure, right?
this is meant to check whether we are still in the initial phase of doubling the batch size, or whether we have failed once (i.e. high is defined) and are thus in the binary search phase
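For reference, a minimal self-contained sketch of that two-phase idea (keep doubling until a trial fails, then binary search between the last working and failing sizes); the run_trial helper is a hypothetical placeholder, not part of this PR:

def find_max_batch_size(run_trial, init_size=2, max_trials=25):
    """Sketch: double the batch size until a trial fails, then binary search the boundary.

    `run_trial(size)` is a hypothetical callable returning True if `size` fits in memory.
    For simplicity this assumes `init_size` itself fits.
    """
    low, high = init_size, None
    new_size = init_size
    for _ in range(max_trials):
        if run_trial(new_size):
            low = new_size
            if high is not None:
                if high - low <= 1:
                    break
                new_size = (high + low) // 2  # binary search phase
            else:
                new_size *= 2  # still in the doubling phase
        else:
            high = new_size
            new_size = (high + low) // 2
    return low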
Sure, I mean:

try:
    ...  # do something
except:
    ...  # do this if the something failed
else:
    ...  # do this only if the something passed
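A tiny self-contained illustration of that try/except/else flow (purely illustrative, not code from this PR):

def fits_in_memory(num_bytes):
    try:
        buffer = bytearray(num_bytes)  # the risky step
    except MemoryError:
        return False  # runs only if the allocation failed
    else:
        del buffer
        return True  # runs only if the allocation succeeded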
I also like this. I really think that it is great that lightning has kept its interface so minimalist for most users, e.g. most users only need to interact with ...
We can explore this on @tullie's GH issue. In the meantime, let's get this merged haha.
This pull request is now in conflict... :(
def is_cudnn_snafu(exception):
    return isinstance(exception, RuntimeError) \
        and len(exception.args) == 1 \
        and "cuDNN error: CUDNN_STATUS_NOT_SUPPORTED." in exception.args[0]
Thanks for this PR to implement toma in PyTorch Lightning!
If you copy code and ideas from my projects, could you please add a mention of it, too? I see that you're a fellow PhD student, so you are aware of the importance of credit assignment.
In particular, if you copy code verbatim and remove helpful comments... maybe add them back.
def is_cuda_out_of_memory(exception):
    return (
        isinstance(exception, RuntimeError) and len(exception.args) == 1 and "CUDA out of memory." in exception.args[0]
    )


def is_cudnn_snafu(exception):
    # For/because of https://github.com/pytorch/pytorch/issues/4107
    return (
        isinstance(exception, RuntimeError)
        and len(exception.args) == 1
        and "cuDNN error: CUDNN_STATUS_NOT_SUPPORTED." in exception.args[0]
    )


def gc_cuda():
    """Gargage collect Torch (CUDA) memory."""
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
is from https://github.com/BlackHC/toma/blob/master/toma/torch_cuda_memory.py.
Now, I think this PR contains lots of other code, and I think it's great, but maybe add a link or a mention.
Thank you,
Andreas
PS: Keeping my rather random method names is a bit of a give-away.
@BlackHC i'm sorry, i had no idea this code was copied... @SkafteNicki generally we like to do our own implementations of features, and under no circumstance do we allow code copying. I suggest a few things to rectify this:
But seeing how the code was copied from the toma repo i would rather play nice and bring them in as an actual dependency. @BlackHC my deepest apologies, i was not aware that this code came from your repo! @PyTorchLightning/core-contributors thoughts?
The rest of the code seems quite original/I haven't reviewed it in detail. I'm sure you have a good understanding of it and its quality because it follows a slightly different approach than toma. With the binary search and the potential of using a higher batch size than specified, it might be worth looking into ghost batchnorm in the future if this is used for training. What would be great is:
I can grant you a license for those lines of code outside of the MIT license used in toma, so it should be fine. No need to rewrite things. Please let me know what you think. Thanks,
@BlackHC I am sorry for this, I was not aware of it...
Thank you very much! From a dependency point of view, it's only the functions I mentioned (and ...). If toma adds lots of functionality it might be worth having another look. It does not scale batch sizes up but only down at the moment, as it follows a slightly different paradigm. I'm really impressed by how quickly you have replied and reacted. I think it's amazing. Thanks,
@williamFalcon we may think about using some other functionalities, but ... which is a very limiting factor for us...
I am so sorry about this. Let me try to explain. I originally tried to integrate toma into lightning (because it is an awesome library) but could not figure out how to get some functionality to work with the lightning interface, especially the binary search and hparams. I therefore ended up doing a custom implementation myself. Everything except for the utility functions for determining when we are out of memory I wrote myself, but I am truly sorry that in the heat of programming I forgot to reference the original source code for these functions. I am sorry if I have offended you @BlackHC, it was never my intent; I would never do that to a fellow PhD student.
We need to resolve the license case in the dependencies, and then I would suggest to @BlackHC to make this simple PR so that he is also on the contributor list (as it is generated from PR authors).
@SkafteNicki no worries, i'm sure it wasn't malicious - we just want to play fair with the broader set of tools. Let's follow @BlackHC's suggestions here and make the correct PRs/adjustments to our codebase. @BlackHC thank you for providing the rights to use! We will make sure to follow the considerations you requested. Thanks!
Thank you very much! I'll prepare a PR shortly (probably this evening as we have an internal NeurIPS deadline before that 😬). @SkafteNicki thanks for explaining and no worries! I think it's great you implemented it in a way that targets PyTorch Lightning specifically, and I'm glad I was able to provide inspiration with toma and that the utility functions were useful. This is a big part of why open-source is great. I hope you'll keep contributing to PyTorch Lightning! Thanks,
Before submitting
What does this PR do?
Fixes #1615 and #1444.
This implements an algorithm that can automatically find the largest possible batch size that fits in memory (no OOM). Currently two modes are supported: power and binsearch. power will iteratively multiply the batch size by 2 until an OOM is encountered and then stop; binsearch will further try to refine the batch size from there through a binary search strategy.
In power (default) the output in the terminal currently looks something like this (running the LightningTemplateModel).
The interface for this feature is currently very much like the learning rate finder introduced some time ago. In the basic case the user can set the trainer flag auto_scale_batch_size=True and the batch finder will run when .fit() is called. Similar to the learning rate finder, this assumes the user has a field called model.hparams.batch_size that can be overridden with whatever batch size is found. If the user instead wants to write to another field, this can be done with auto_scale_batch_size=my_field (corresponding to model.hparams.my_field).
For the power user: after initializing the trainer, you can invoke the method scale_batch_size and thereby control the search through the method's parameters.
WIP right now as tests and better documentation are missing. It also needs to be figured out exactly where this should be located in the codebase: currently in TrainerTrainingTricksMixin, but it should maybe be its own Mixin.
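A minimal usage sketch of the interface described above; MyModel stands in for the user's own LightningModule, and argument names such as mode and max_trials are assumptions for illustration rather than the confirmed final API:

from argparse import Namespace

import pytorch_lightning as pl

# MyModel is a placeholder for any LightningModule whose hparams contain a `batch_size` field
hparams = Namespace(batch_size=32)
model = MyModel(hparams)

# Basic case: the batch size finder runs when .fit() is called
trainer = pl.Trainer(auto_scale_batch_size=True)
trainer.fit(model)

# Power-user case: invoke the search directly and control it through its parameters
trainer = pl.Trainer()
new_size = trainer.scale_batch_size(model, mode='binsearch', max_trials=25)
model.hparams.batch_size = new_size
trainer.fit(model)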
PR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in Github issues there's a high chance it will not be merged.
Did you have fun?
Make sure you had fun coding 🙃