-
Notifications
You must be signed in to change notification settings - Fork 227
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix LR schedule handling for low-bit optimizers #736
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/736
Note: Links to docs will display an error until the docs builds have been completed. ✅ No FailuresAs of commit cdf283c with merge base 68e4643 ( This comment was automatically generated by Dr. CI and updates every 15 minutes. |
copy paste is mostly fine here to me (unless it really bothers you) even pytorch core has lots of copy paste between adam and adamw |
@janeyx99 Since CPU scalar Tensor works with n-D CUDA Tensor, should But incrementing UPDATE: nvm, I saw that @msaroufim It's just annoying when I make changes to the wrong file and need to be careful with the name - copy paste but still need to replace adam with adamw or vice versa 😅 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Direction is certainly what was discussed, so nice! Change looks mostly good to me.
- Regarding merging adam(w) together, there’s actually been an action item to merge the impls and make AdamW a subclass, so I understand exactly what you mean 🥲. That said, this change can def go in before you address that as this perf gap is more crucial to close.
- Yes, step on CPU could be more performant (we normally have our steps be CPU tensors in core optims)
- Logging/calling .item() esp on a CUDA tensor is expensive! just fyi
@@ -20,7 +20,7 @@ def __init__(self, params, lr, betas, eps, weight_decay, amsgrad, *, block_size) | |||
raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) | |||
if not 0.0 <= betas[1] < 1.0: | |||
raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) | |||
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad) | |||
defaults = dict(lr=torch.tensor(lr), betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
is this the right place to set it? i would advocate for making this very obvious with users what is going on, and i would modify the AdamW constructor below.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
unless you’re adamant that this optimizer will only support Tensor lrs. In pytorch/pytorch, we support python float lrs in eager because it is faster to compute python math than launch kernels, though that may be less relevant here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The alternative I’d suggest for max visibility from user is changing line 165 below to be torch.tensor(1e-3)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think to keep it simple, we just enforce lr to be a scalar tensor here. If lr is not changed during training, whether it is a tensor or not would not matter. But if it is changed during training, we need lr to be a tensor anyway since torch.compile will recompile when python float lr changes value.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yea, but I still think it’d be most visible/clear to users if the constructors of Adam and AdamW clearly set the default as Tensors and if this base would just error if lr was not a Tensor.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just to clarify, do you mean that forcing users to pass in LR as Tensor instead of Python float? i.e. error is raised if LR is not a Tensor.
And is it your view that converting Python float to Tensor in the constructor might seem unexpected to the users?
Personally when using optimizers, I never pass in Tensor LR before, so it feels strange to me 😅 (doesn't mean I'm correct, just a feeling from my limited experience). I think that converting LR from float to Tensor inside the constructor is an implementation detail that the users shouldn't need to care about.
Also, most (if not all?) other optimizers will work if I pass in a Python float LR? So feel kinda strange (again 🤣) if users are forced to pass in Tensor LR to this particular optimizer.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, I think it'd be good to have users pass in Tensor lrs, so they're most aware of what is going on. I think it is not great to switch it up under the hood and have the user be confused if there's ever an error regarding the Tensorness of the lr.
(I will update the benchmarks number in README. don't merge yet) |
PyTorch (fused) | 12.23 | 41.8 | 94.38 | ||
bnb 8-bit | 8.32 | 43.6 | 94.18 | ||
ao 8-bit | 8.33 | 42.6 | 94.25 | ||
ao FP8 E4M3 | 9.27 | 44.1 | 94.40 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
FP8 max memory is ~1GB higher than expected. I re-ran the benchmark on main branch (without this PR) and the max memory for FP8 is the same. I'm suspecting something funny happening with torch.compile. The benchmark was done with 2.5.0.dev20240820
. I don't think it's a big issue, especially since FP8 optimizer is not popular yet (may change in the future though 👀). Re-running with 2.4 now. (probably won't re-run the rest with 2.4 since I'm lazy)
(Accuracy is much better than before across the board thanks to cosine LR scheduler)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
FP8 optimizer
PyTorch version | max mem (GB) | imgs/s | acc |
---|---|---|---|
2.4.0 | 9.04 | 42.5 | 94.18 |
2.5.0.dev20240820 | 9.27 | 44.1 | 94.40 |
Definitely something funny with newer torch.compile 🤔
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
On 2.3.1, Triton refused to run due to compute capability (though my GPU should support - 2.4.0 and nightly are fine)
Conversion from/to f8e4m3nv is only supported on compute capability >= 90Conversion from/to f8e4m3nv is only supported on compute capability >= 90
UNREACHABLE executed at ../../../lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp:823!
In the original benchmark numbers, FP8 optim was good (max memory is the same as 8-bit optim). But I don't rmb which PyTorch version I used back then 😅.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is just compiling the optimizer. There were some larger changes to the min cut partioner that affect how we split forward + bwd graphs ( determine what to recompute vs save). These changes show up more in fp8 since you typically have long chains of ops to dequant or quant. But if there is no fwd/bwd I am not totally sure what might be happening
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
torch.compile fishiness warrants asking someone who may know more cc @drisspg are you familiar with any changes from 2.4 to now that’d affect compile on fp8 e3m4?
otherwise i think this PR is okay to go (though i think Tensor lr should still be more explicit).
Fixes #730
Tested with ViT-Base (~100M) params. Excellent speed up 😄. The "before" is much slower than my previously reported numbers (3-4%) is probably because I use a much smaller model -> allocating CUDA LR tensor is much more evident.
Marking this as draft right now because I want to consolidate Adam and AdamW into a single file. It's becoming a bit troublesome to make the same edits in both files whenever I need to change anything 😅.
cc @janeyx99 @msaroufim