
long-standing Bug in Adafactor optimizer if beta1 > 0 #34506

Closed

dxqbYD opened this issue Oct 30, 2024 · 13 comments

Comments

@dxqbYD

dxqbYD commented Oct 30, 2024

System Info

all known

Who can help?

@muellerz @SunMarc

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

There seems to be an issue with the Adafactor optimizer found here, if beta1 is > 0:

class Adafactor(Optimizer):

If beta1 is > 0, use_first_moment == True.
But the code does not actually accumulate a first moment of the gradients. In the following line, it adds to exp_avg not the gradients themselves, but a prepared mixture of gradient, learning rate and the inverse square root of the second moment (the variable called 'update'):

exp_avg.mul_(group["beta1"]).add_(update, alpha=(1 - group["beta1"]))

There is nothing in the Adafactor paper that would indicate that this is intentional:
https://arxiv.org/pdf/1804.04235

The Adafactor paper focuses on beta1 = 0, as its proposal is to save VRAM by using no first moment (see section 4) and only a factorized second moment (see section 3). For that use case, beta1 is 0 and the update variable mentioned above is prepared correctly, but not if beta1 is > 0.

In Algorithm 1 on page 2 of the paper, they start out from the Adam algorithm, which updates the first moment with the gradients alone. Nowhere in the paper is it proposed to change that.

Consider this pseudocode, leaving out the betas:

AdamW:
exp_avg_sq += grad^2
exp_avg += grad
p += -lr * exp_avg / sqrt(exp_avg_sq)

Adafactor with use_first_moment == True:
exp_avg_sq += grad^2  # or its factored approximation
exp_avg += lr * grad / sqrt(exp_avg_sq)
p += -exp_avg

Note that the learning rate and the inverse square root of exp_avg_sq are already baked into exp_avg, instead of being applied only when the parameters are updated. The trajectory diverges from Adam starting from the 2nd step, as the sketch below shows.
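To make the divergence concrete, here is a minimal scalar simulation of the two pseudocode variants above (my own illustrative sketch, not the optimizer code; the constant gradient and learning rate are arbitrary, and the betas and epsilon are left out as in the pseudocode):

```python
# Illustrative only: constant gradient, no betas, no epsilon, arbitrary lr.
lr, g = 0.1, 2.0

v_a = m_a = p_a = 0.0  # AdamW-style accumulators and parameter
v_f = m_f = p_f = 0.0  # Adafactor-style (use_first_moment == True) accumulators and parameter

for step in range(1, 4):
    # AdamW: first moment holds raw gradients, preconditioning is applied at update time
    v_a += g ** 2
    m_a += g
    p_a += -lr * m_a / v_a ** 0.5

    # Adafactor with beta1 > 0: first moment holds the already-preconditioned update
    v_f += g ** 2
    m_f += lr * g / v_f ** 0.5
    p_f += -m_f

    print(step, round(p_a, 4), round(p_f, 4))

# step 1: both take the same step (-0.1)
# step 2: -0.2414 vs -0.2707 -- the trajectories have already diverged
```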

I am reporting this issue to transformers because this seems to be the current reference implementation, with many projects either using it or copying it directly from this repo.

But this issue didn't originate here.
It can even be found in the tensor2tensor code that is mentioned by the authors in the paper:
https://github.com/tensorflow/tensor2tensor/blob/bafdc1b67730430d38d6ab802cbd51f9d053ba2e/tensor2tensor/utils/adafactor.py#L265

If they used this code for the experiments on page 7, they might have underestimated the difference between using the first moment and not using it, because the first moment wasn't working correctly.

The impact of this issue today might be limited, because when people use Adafactor they usually do it for the VRAM savings and set beta1 to 0. It might still mislead people into thinking that the first moment is less important than it is, if they compare beta1 == 0 against beta1 > 0 within Adafactor.
And someone might actually use Adafactor with beta1 > 0 for the memory savings of the factorized second moment and expect the first moment to work like Adam's.

Expected behavior

exp_avg should accumulate the raw gradients, as implemented in Adam.
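As a sketch of what that would mean inside the current step (simplified and untested; exp_avg, exp_avg_sq, grad, p and group are the existing variables, and lr stands for whatever step size the optimizer has already prepared):

```python
# Current behaviour (simplified): 'update' already contains the learning rate and
# grad * exp_avg_sq.rsqrt(), and that preconditioned value is averaged into exp_avg:
#   exp_avg.mul_(group["beta1"]).add_(update, alpha=(1 - group["beta1"]))
#   p.add_(-exp_avg)

# Adam-style behaviour expected here: exp_avg averages the raw gradients,
# and the second-moment preconditioning is applied only when stepping the parameter.
exp_avg.mul_(group["beta1"]).add_(grad, alpha=(1 - group["beta1"]))
p.add_(exp_avg * exp_avg_sq.rsqrt(), alpha=-lr)
```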

@dxqbYD dxqbYD added the bug label Oct 30, 2024
@dxqbYD dxqbYD changed the title long-stand Bug in Adafactor optimizer if beta1 > 0 long-standing Bug in Adafactor optimizer if beta1 > 0 Oct 30, 2024
@Arcitec

Arcitec commented Oct 31, 2024

This sounds like a bug. Amazing discovery.

@Rocketknight1
Member

Hey - we're aware of this report, and it's extremely interesting if it checks out, but it'll take us a little bit to sit down with the papers and validate everything!

@Rocketknight1
Member

@dxqbYD we checked it out and we think you're right that the implementation for beta > 0 is suboptimal, although it's almost certainly the one used by the original paper.

Would you be willing to open a PR adding a kwarg to enable the new behaviour? We can default to the old behaviour for backward compatibility for now, but if the improvements seem solid then we might switch the default to your implementation.

@dxqbYD
Author

dxqbYD commented Nov 1, 2024

> @dxqbYD we checked it out and we think you're right that the implementation for beta > 0 is suboptimal, although it's almost certainly the one used by the original paper.
>
> Would you be willing to open a PR adding a kwarg to enable the new behaviour? We can default to the old behaviour for backward compatibility for now, but if the improvements seem solid then we might switch the default to your implementation.

Thank you for taking the time to look at this!
I can open a PR, but what would you call this parameter, and when would you switch it to default behaviour?

I think it's going to be difficult to show that there are improvements. Someone would have to do paper-level research showing that some validation metric improves, and even that is open to bias.

I'd honestly just fix it; otherwise you end up with a "temporary" kwarg that's really forever and that nobody uses.
Or alternatively: decide that, whether the authors intended it or not, this is what Adafactor is now, and won't-fix it.

I'd make the argument for fixing it, because it is quite clear that the Adafactor authors wanted to propose "Adam plus", not change Adam.

@Rocketknight1
Member

Rocketknight1 commented Nov 1, 2024

@dxqbYD The problem there is we're in slightly uncharted territory anyway! Even though your theoretical arguments are sound, we can't rule out that performance will be worse with the fix, because optimizers can be counterintuitive sometimes. At the very least we'd like some hints from the community that Adafactor is not noticeably worse with the change before we make it the default, and we do have a preference for backward compatibility unless there's a strong reason to break things.

I think if we add it as a kwarg like fixed_momentum or something, with a detailed explanation in the docstring (which will get added to the docs as well), the kind of people who are using Adafactor with non-default beta1 are likely to see it, and I suspect a lot of them will enable it. We can revisit this later to make it the default once we have any empirical data one way or the other, but I'm a little nervous about just silently changing it.

@dxqbYD
Author

dxqbYD commented Nov 1, 2024

> @dxqbYD The problem there is we're in slightly uncharted territory anyway! Even though your theoretical arguments are sound, we can't rule out that performance will be worse with the fix, because optimizers can be counterintuitive sometimes. At the very least we'd like some hints from the community that Adafactor is not noticeably worse with the change before we make it the default, and we do have a preference for backward compatibility unless there's a strong reason to break things.
>
> I think if we add it as a kwarg like fixed_momentum or something, with a detailed explanation in the docstring (which will get added to the docs as well), the kind of people who are using Adafactor with non-default beta1 are likely to see it, and I suspect a lot of them will enable it. We can revisit this later to make it the default once we have any empirical data one way or the other, but I'm a little nervous about just silently changing it.

Ok. Will open a PR.
I'd ask someone else to do/check the doc, because English is not my first language.

@rwightman
Contributor

rwightman commented Nov 1, 2024

I don't think changing the default behaviour is appropriate, as this matches tensor2tensor and, I believe, also the popular JAX (optax / flax) impls. So it would need a kwarg if implemented.

That said, I don't believe this is a bug. While it is not addressed in the paper, it matches the t2t impl, which was done with feedback from the paper authors, as per the paper's acknowledgements. Also, this behaviour matches RMSProp, which likewise includes the rms-scaled update in the momentum buffer when momentum is enabled. So it's not 'out of the blue' or an anomaly.

Also worth noting: there was another proposed variant of Adafactor in Scaling Vision Transformers (https://arxiv.org/abs/2106.04560). There is an impl in big_vision: https://github.com/google-research/big_vision/blob/main/big_vision/optax.py

I'm sure the authors involved in the above variant were principled in their modifications and comparisons. They always have the first moment enabled in this impl, though in half precision. They change things a bit in that the LR is applied after the first moment (interestingly, RMSProp in TF vs PyTorch differed in this way too), but the momentum buffer here is also accumulating the updates after the rms scaling was applied, not separately as in the Adam update rule. I think an impl of this 'BigVisionAdafactor' would be more interesting personally :) I might look at that sometime.

EDIT: I don't make the calls here, so whether to add support for the proposed changes is not my decision; these are my opinions and comments :)

@dxqbYD
Author

dxqbYD commented Nov 1, 2024

> I don't think changing the default behaviour is appropriate, as this matches tensor2tensor and, I believe, also the popular JAX (optax / flax) impls. So it would need a kwarg if implemented.

Changing the default behaviour is off the table because the maintainers want to be careful, which I understand.
But I'd still consider this a bug - the question is only when and by whom it was introduced, and whether we consider it "part of the package" now because it might have been the original authors.

Yes, t2t also has it.
However, it does not seem that the google/big_vision implementation has this bug. I've just looked at it and they seem to treat the two EMAs completely separately.

> , but the momentum buffer here is also accumulating the updates after the rms scaling was applied,

You could argue whether RMS scaling should go into exp_avg or not, but that might have been intentional and is not what I am arguing above.
What was likely not intentional is folding the inverse square root of the second moment into the first moment. Those are two different things.

RMS scaling is here:

update.div_((self._rms(update) / group["clip_threshold"]).clamp_(min=1.0))

What should not go into exp_avg is here:

update = exp_avg_sq.rsqrt().mul_(grad)

@rwightman
Contributor

rwightman commented Nov 2, 2024

My brain is fried at the end of the week, and tracing through optimizer code in different frameworks is always fun. Initially I thought the big vision impl was doing this differently, but after another pass, only the place where LR scaling is done appears fundamentally different...

When I said 'scale by rms' I was referring to the optax naming for the factored scaling ('scale_by_factored_rms'), i.e. the second moment & related calculations. I was not referring to the 'clipping by block rms' operation or 'scale by block rms' there.

In both cases below, the momentum buffer is doing a running average of the updates (grad * exp_avg_sq.rsqrt()). In big vision, the LR is applied after momentum instead of before. Weight decay is also a bit different wrt scaling by LR.

The optax chain for big vision is:

scale_by_factored_rms
clip_by_block_rms
momentum  # first moment
scale_by_learning_rate  # this is above momentum in impl below
weight_decay
negate

The optax chain for the optax impl of adafactor, which I believe matches the transformers / tensor2tensor impl:

  tx = [
      factorized.scale_by_factored_rms(
          factored, decay_rate, decay_offset, min_dim_size_to_factor, eps)]
  # This basic rescaling is typically combined with one or more of the following
  # transformation (all can be disabled via adafactor's constructor args).
  if clipping_threshold is not None:
    tx.append(clipping.clip_by_block_rms(clipping_threshold))
  if learning_rate is not None:
    tx.append(transform.scale_by_learning_rate(learning_rate, flip_sign=False))
  if multiply_by_parameter_scale:
    tx.append(transform.scale_by_param_block_rms())
  if momentum is not None:
    tx.append(
        transform.ema(momentum, debias=False, accumulator_dtype=dtype_momentum))
  if weight_decay_rate is not None:
    tx.append(transform.add_decayed_weights(
        weight_decay_rate, mask=weight_decay_mask))
  # In gradient "descent" we follow the negative gradient.
  tx.append(transform.scale(-1))
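
In other words (an illustrative sketch, not actual optax code; clipping, weight decay, parameter scaling and the exact EMA weighting are omitted):

```python
import math

def ema(buf, value, beta=0.9):
    # Plain exponential moving average stand-in for the optax momentum/ema transforms;
    # the real transforms differ in debiasing and dtype handling.
    return beta * buf + (1 - beta) * value

grad, v, lr = 1.0, 4.0, 1e-2  # illustrative scalars: gradient, second moment, learning rate

# big_vision-style chain: momentum averages the preconditioned update,
# the learning rate is applied after the momentum step.
m_bv = ema(0.0, grad / math.sqrt(v))
step_bv = lr * m_bv

# optax adafactor / tensor2tensor / transformers-style chain: the learning rate is
# already folded into the value that momentum averages.
m_t2t = ema(0.0, lr * grad / math.sqrt(v))
step_t2t = m_t2t

# With a constant lr the two coincide; with a schedule they differ, because the
# t2t-style buffer mixes the learning rates of past steps into the average.
```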

And again, going back to 'what should not go into exp_avg', rmsprop is doing something similar. Instead of maintaining grad^2 and grad moments separately and then applying them, it calculates the grad^2 moment, then calculates the update from it as grad / sqrt(grad^2 buf), and then smooths this update if the first moment is enabled.

https://github.com/pytorch/pytorch/blob/ee2f8a50d3527fad3845a7d67ab5467bfaf6c0fe/torch/optim/rmsprop.py#L309-L330
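For reference, this is roughly what the momentum path in the linked torch.optim.RMSprop code does (a simplified sketch, with the centered variant and weight decay omitted):

```python
import torch

param = torch.zeros(3)
grad = torch.ones(3)
square_avg = torch.zeros(3)  # EMA of grad^2
buf = torch.zeros(3)         # momentum buffer
lr, alpha, momentum, eps = 1e-2, 0.99, 0.9, 1e-8

square_avg.mul_(alpha).addcmul_(grad, grad, value=1 - alpha)
avg = square_avg.sqrt().add_(eps)
buf.mul_(momentum).addcdiv_(grad, avg)  # the momentum buffer accumulates the
                                        # rms-preconditioned update, not the raw gradient
param.add_(buf, alpha=-lr)
```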

I think more than an 'it's not like Adam' argument is needed to show it's actually a bug and not intentional. That, and some strong empirical evidence showing the alternate formulation works better in some scenarios.

@rwightman
Contributor

Reading with fresh eyes, I believe my comments made sense.

In the optax impl, after 'scale_by_factored_rms' the subsequent ops in the chain are operating on the 'update', aka the subtrahend, aka the value to be subtracted from the params. Same as with the tensor2tensor impl and the impl here.

The proposed changes may certainly be better, but I'd be looking for strong evidence of that before taking on the extra maintenance burden of an alternate code path. It's unlikely there was much testing of beta1 > 0 in the original, but we do know the big vision impl was tested for some large ViT training with beta1 > 0, and it worked well for them.


This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@dxqbYD
Author

dxqbYD commented Nov 30, 2024

> This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
>
> Please note that issues that do not follow the contributing guidelines are likely to be ignored.

I'll submit that PR, I just haven't had time yet.


This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions github-actions bot closed this as completed Jan 2, 2025
@SunMarc SunMarc reopened this Jan 2, 2025