long-standing Bug in Adafactor optimizer if beta1 > 0 #34506
Comments
This sounds like a bug. Amazing discovery.
Hey - we're aware of this report, and it's extremely interesting if it checks out, but it'll take us a little bit to sit down with the papers and validate everything!
@dxqbYD we checked it out and we think you're right that the implementation for beta1 > 0 doesn't match the paper. Would you be willing to open a PR adding a kwarg to enable the new behaviour? We can default to the old behaviour for backward compatibility for now, but if the improvements seem solid then we might switch the default to your implementation.
Thank you for taking the time to look at this! I think it's going to be difficult to show that there are improvements. Someone would have to do paper-level research showing that some validation metric improves, and even that is open to bias. I'd honestly just fix it; otherwise you end up with a "temporary" kwarg that's really forever and that nobody uses. I'd make the argument for fixing it, because it is quite clear that the Adafactor authors wanted to propose "Adam plus", not change Adam.
@dxqbYD The problem there is we're in slightly uncharted territory anyway! Even though your theoretical arguments are sound, we can't rule out that performance will be worse with the fix, because optimizers can be counterintuitive sometimes. At the very least we'd like some hints from the community that Adafactor is not noticeably worse with the change before we make it the default, and we do have a preference for backward compatibility unless there's a strong reason to break things. I think adding it as a kwarg is the right first step.
Ok. Will open a PR.
I don't think changing the default behaviour is appropriate, as this matches tensor2tensor and, I believe, also the popular JAX (optax / flax) implementations. So it would need a kwarg if implemented.

That said, I don't believe this is a bug. While it isn't addressed in the paper, it matches the t2t implementation, which was done with feedback from the paper authors as per the paper's acknowledgements. Also, this behaviour matches RMSProp, which likewise includes the rms-scaled update in the momentum buffer when momentum is enabled. So it's not 'out of the blue' or an anomaly.

Also worth noting, there was another proposed variant of Adafactor in Scaling Vision Transformers (https://arxiv.org/abs/2106.04560). There is an impl in big_vision: https://github.com/google-research/big_vision/blob/main/big_vision/optax.py. I'm sure the authors involved in that variant were principled in their modifications and comparisons. They always have the first moment enabled in this impl, though in half precision. They change things a bit in that the LR is applied after the first moment (interestingly, RMSProp in TF vs PyTorch differed in this way too), but the momentum buffer here is also accumulating the updates after the rms scaling was applied, not separately as in the Adam update rule. I think an impl of this 'BigVisionAdafactor' would be more interesting personally :) I might look at that sometime.

EDIT: I don't make the calls here, so it's not my decision whether to add support for the proposed changes or not; these are my opinions and comments :)
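To make the ordering difference concrete, here is a rough scalar sketch of the two variants as I read them (my own paraphrase with illustrative names, not the actual t2t / optax / big_vision code; factoring, clipping, eps, weight decay and the exact momentum averaging are all simplified away):

```python
# Rough paraphrase, not any library's actual implementation.

def step_t2t_style(p, g, m, v, lr, beta1=0.9, beta2=0.999):
    """Momentum buffer averages the LR-scaled, preconditioned update (t2t / transformers style)."""
    v = beta2 * v + (1 - beta2) * g * g
    update = lr * g / v ** 0.5
    m = beta1 * m + (1 - beta1) * update
    return p - m, m, v

def step_big_vision_style(p, g, m, v, lr, beta1=0.9, beta2=0.999):
    """Momentum buffer averages the preconditioned update; LR is applied after momentum."""
    v = beta2 * v + (1 - beta2) * g * g
    update = g / v ** 0.5
    m = beta1 * m + (1 - beta1) * update
    return p - lr * m, m, v
```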
Changing the default behaviour is off the table because the maintainers want to be careful, which I understand. Yes, t2t also has it.

You could argue whether the RMS scaling should go into exp_avg or not, but that might have been intentional and is not what I am arguing above. The RMS scaling is here:
transformers/src/transformers/optimization.py, line 894 (commit 8a734ea)

What should not go into exp_avg is here:
transformers/src/transformers/optimization.py, line 892 (commit 8a734ea)
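To spell out which pieces end up inside exp_avg, here is a hedged paraphrase of the update path in question (unfactored second moment, illustrative variable names; this is not the exact code at the lines referenced above):

```python
import torch

def adafactor_like_step(p, grad, exp_avg, exp_avg_sq, lr=1e-3,
                        beta1=0.9, beta2=0.999, clip_threshold=1.0):
    # second moment (shown unfactored for brevity)
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
    # precondition the gradient by the second moment
    update = grad * exp_avg_sq.rsqrt()
    # RMS clipping of the update (the "RMS scaling" part, arguably fine to keep)
    update.div_(torch.clamp(update.pow(2).mean().sqrt() / clip_threshold, min=1.0))
    # learning rate folded into the update
    update.mul_(lr)
    # the first moment then accumulates the fully prepared update, not the raw gradient
    exp_avg.mul_(beta1).add_(update, alpha=1 - beta1)
    p.sub_(exp_avg)
```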
My brain is fried at the end of the week, and tracing through optimizer code in different frameworks is always fun. Initially I thought the big_vision impl was doing this differently, but after another pass only the place where the LR scaling is done appears fundamentally different. When I said 'scale by rms' I was referring to the optax naming for the factored scaling ('scale_by_factored_rms'), i.e. the second moment and related calculations. I was not referring to the 'clipping by block rms' operation or 'scale by block rms' there. In both cases below the momentum buffer is doing a running average of the updates (grad * exp_avg_sq.rsqrt()). In big_vision, the LR is applied after momentum instead of before. Weight decay is also a bit different w.r.t. scaling by LR. The optax chain for big vision is:
The optax chain for the optax impl of adafactor, which I believe matches the transformers / tensor2tensor impl:
And again, going back to 'what should not go into exp_avg': RMSProp is doing something similar. Instead of maintaining grad^2 and grad moments separately and then applying them, it calculates the grad^2 moment, then calculates the update from grad / sqrt(grad^2 buf), and then smooths this update if the first moment is enabled. I think more than an 'it's not like Adam' argument is needed to show it's actually a bug and not intentional. That, and some strong empirical evidence that the alternate formulation works better in some scenarios.
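For reference, the RMSProp-with-momentum recurrence looks roughly like this (a simplified scalar sketch of the documented torch.optim.RMSprop update, leaving out centering and weight decay): the momentum buffer averages the already-preconditioned update, not the raw gradient.

```python
def rmsprop_momentum_step(p, g, buf, v, lr=1e-2, alpha=0.99, momentum=0.9, eps=1e-8):
    v = alpha * v + (1 - alpha) * g * g          # running average of squared gradients
    buf = momentum * buf + g / (v ** 0.5 + eps)  # momentum over the scaled update
    return p - lr * buf, buf, v
```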
Reading with fresh eyes, I believe my comments made sense. In the optax impl, after 'scale_by_factored_rms' the subsequent ops in the chain operate on the 'update', aka the subtrahend, aka the value to be subtracted from the params, as with the tensor2tensor impl and the impl here. The proposed changes may certainly be better, but I'd be looking for strong evidence of that before taking on the extra maintenance burden of an alternate code path. It's unlikely there was much testing of beta1 > 0 in the original, but we do know the big_vision impl was tested for some large ViT training, with beta1 > 0, and it worked well for them.
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
I'll submit that PR, just didn't have time yet
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
System Info
all known
Who can help?
@muellerz @SunMarc
Reproduction
There seems to be an issue with the Adafactor optimizer found here, if beta1 is > 0:
transformers/src/transformers/optimization.py, line 672 (commit 8a734ea)
If beta1 is > 0, use_first_moment == True. But the code does not actually use the first moment (only): in the following line, it adds to exp_avg not only the gradients, but a prepared mixture of the gradients, the learning rate, and the inverse square root of the second moment (called 'update'):
transformers/src/transformers/optimization.py, line 899 (commit 8a734ea)
There is nothing in the Adafactor paper that would indicate that this is intentional:
https://arxiv.org/pdf/1804.04235
The Adafactor paper focuses on beta1 = 0, as their proposal is to save VRAM by using no first moment (see section 4) and only a factored second moment (see section 3). For this use case, beta1 is 0 and the update variable mentioned above is prepared correctly, but not if beta1 is > 0.

In Algorithm 1 on page 2 of the paper, they start out with the Adam algorithm, which updates the first moment just with the gradients. Nowhere in the paper is it proposed to change that.
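Written out in my own notation (ignoring RMS clipping and bias correction; $\alpha_t$ is the step size and $\hat{V}_t$ the (factored) second-moment estimate), the Adam-style first-moment update the paper starts from, versus what the implementation effectively accumulates when beta1 > 0, is:

$$M_t = \beta_1 M_{t-1} + (1 - \beta_1)\, G_t \qquad \text{(Adam-style first moment, Algorithm 1)}$$

$$M_t = \beta_1 M_{t-1} + (1 - \beta_1)\, \alpha_t\, \frac{G_t}{\sqrt{\hat{V}_t}} \qquad \text{(what exp\_avg accumulates here when } \beta_1 > 0\text{)}$$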
Consider this pseudocode, leaving out the betas:
AdamW:
    exp_avg_sq += grad^2
    exp_avg    += grad
    p          += -lr * exp_avg / sqrt(exp_avg_sq)

Adafactor with use_first_moment == True:
    exp_avg_sq += grad^2   # or its factored approximation
    exp_avg    += lr * grad / sqrt(exp_avg_sq)
    p          += -exp_avg
Note that grad and exp_avg_sq (and lr) are already combined inside exp_avg, not only in the parameter update. It'll diverge from Adam starting from the 2nd step (a minimal numeric check is sketched below).

I am reporting this issue to transformers, because this seems to be the current reference implementation, with many projects either using or directly copying from this repo.
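Here is the minimal numeric check referred to above (a toy scalar setup of my own: constant gradient, identical second-moment handling on both sides, no eps / bias correction / clipping, so any difference comes purely from what enters the first-moment buffer):

```python
g, lr, beta1, beta2 = 1.0, 0.1, 0.9, 0.999
v = m_adam = m_ada = 0.0

for step in range(1, 4):
    v = beta2 * v + (1 - beta2) * g * g

    # Adam-style: the first moment averages the raw gradient
    m_adam = beta1 * m_adam + (1 - beta1) * g
    delta_adam = lr * m_adam / v ** 0.5

    # current Adafactor-with-beta1 style: the first moment averages the prepared update
    m_ada = beta1 * m_ada + (1 - beta1) * (lr * g / v ** 0.5)
    delta_ada = m_ada

    print(f"step {step}: adam step = {delta_adam:.4f}, adafactor step = {delta_ada:.4f}")
# the two parameter steps agree at step 1 and drift apart from step 2 onward
```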
But this issue didn't originate here.
It can even be found in the tensor2tensor code that is mentioned by the authors in the paper:
https://github.com/tensorflow/tensor2tensor/blob/bafdc1b67730430d38d6ab802cbd51f9d053ba2e/tensor2tensor/utils/adafactor.py#L265
If they used this code for their experiments, they might have underestimated the difference between using the first moment and not using it (the experiments on page 7), because using the first moment wasn't working correctly.
The impact of this issue today might be limited, because when people use Adafactor, they usually do it for its VRAM savings and set beta1 to 0. It might still confuse people, though, into thinking that the first moment is less important than it is, if they compare beta1 == 0 and beta1 > 0 runs that both use Adafactor.
And someone might actually use it with beta1 > 0 for its memory savings from factorizing the second moment, and expect the first moment to work as in Adam.
Expected behavior
Update of exp_avg as implemented in Adam.
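For illustration only (not the PR itself; shown against an unfactored second moment with RMS clipping omitted, and with illustrative names and eps placement), the expected ordering would look roughly like this:

```python
import torch

def adam_style_first_moment_step(p, grad, exp_avg, exp_avg_sq, lr=1e-3,
                                 beta1=0.9, beta2=0.999, eps=1e-30):
    # second moment as before (unfactored here for brevity)
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
    # Adam-style first moment: average the *raw* gradient
    exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
    # precondition and apply the learning rate afterwards, as in Adam
    update = lr * exp_avg * (exp_avg_sq + eps).rsqrt()
    p.sub_(update)
```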